File size: 809 Bytes
c86a1af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#!/usr/bin/env python3

# This script extracts the encoder-only part of the full T5 v1.1 XL model
# (google/t5-v1_1-xl) and saves it as a standalone T5EncoderModel checkpoint.

from transformers import T5ForConditionalGeneration, T5EncoderModel

src_model_name = "google/t5-v1_1-xl"
dst_dir = "./t5-v1_1-xl-encoder-only"

# Load only the encoder directly from the full encoder-decoder checkpoint.
# T5EncoderModel uses the same parameter names as T5ForConditionalGeneration for
# the shared embedding and the encoder stack, so from_pretrained maps them
# one-to-one; the checkpoint's decoder.* / lm_head.* weights are simply unused.
# This avoids holding the full model AND a second, randomly initialised encoder
# in memory at the same time. The config (and therefore tokenizer compatibility)
# is taken from the source checkpoint.
encoder_model, loading_info = T5EncoderModel.from_pretrained(
    src_model_name, output_loading_info=True
)

# from_pretrained only *warns* about missing weights and leaves them randomly
# initialised; fail loudly instead so a partially random encoder is never saved.
missing_keys = loading_info["missing_keys"]
if missing_keys:
    raise RuntimeError(
        f"Encoder weights missing from {src_model_name}: {missing_keys}"
    )

# Only the model weights and config are written; load the tokenizer from
# src_model_name when using the extracted encoder.
encoder_model.save_pretrained(dst_dir)
print(f"Encoder-only model saved to {dst_dir}")