File size: 809 Bytes
c86a1af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
#!/usr/bin/env python
"""Extract the encoder-only part of a full T5 (encoder-decoder) checkpoint.

The script loads a ``T5ForConditionalGeneration`` checkpoint and drops every
``decoder.*`` and ``lm_head.*`` tensor from its state dict. It loads the
remaining weights into a ``T5EncoderModel`` built from the same config, then
writes the result with ``save_pretrained``.

Usage (running with no arguments reproduces the original defaults)::

    python <this script> [SRC_MODEL_NAME] [DST_DIR]

Only the model weights and config are written. The tokenizer is NOT copied
into ``DST_DIR``; load it from the source model instead.
"""
import argparse

from transformers import T5ForConditionalGeneration, T5EncoderModel

DEFAULT_SRC_MODEL_NAME = "google/t5-v1_1-xl"
DEFAULT_DST_DIR = "./t5-v1_1-xl-encoder-only"

# State-dict key prefixes that belong only to the decoder side of the model.
_DECODER_KEY_PREFIXES = ("decoder.", "lm_head.")


def extract_encoder(src_model_name=DEFAULT_SRC_MODEL_NAME, dst_dir=DEFAULT_DST_DIR):
    """Save the encoder-only weights of ``src_model_name`` into ``dst_dir``.

    Args:
        src_model_name: Hub id or local path accepted by
            ``T5ForConditionalGeneration.from_pretrained``.
        dst_dir: Directory passed to ``T5EncoderModel.save_pretrained``.

    Raises:
        RuntimeError: from ``load_state_dict`` if the filtered weights do not
            match the encoder model's parameters exactly.
    """
    full_model = T5ForConditionalGeneration.from_pretrained(src_model_name)
    # Reuse the same config so the encoder has identical dimensions and vocab
    # size. The source model's tokenizer should therefore still apply.
    config = full_model.config

    # Keep the shared embedding and encoder weights; drop the rest.
    encoder_state_dict = {
        key: tensor
        for key, tensor in full_model.state_dict().items()
        if not key.startswith(_DECODER_KEY_PREFIXES)
    }

    # Drop our reference to the full model before allocating the encoder.
    # The decoder / lm_head tensors can then be released, so they are not
    # held alongside a second (encoder) copy. The encoder tensors stay alive
    # through encoder_state_dict.
    del full_model

    encoder_model = T5EncoderModel(config)
    # strict=True: any missing or unexpected key raises instead of silently
    # leaving randomly initialised weights in the saved model.
    encoder_model.load_state_dict(encoder_state_dict, strict=True)
    del encoder_state_dict

    encoder_model.save_pretrained(dst_dir)


def main(argv=None):
    """Parse the optional source/destination arguments and run the extraction."""
    parser = argparse.ArgumentParser(
        description="Extract the encoder-only part of a full T5 model."
    )
    parser.add_argument(
        "src_model_name",
        nargs="?",
        default=DEFAULT_SRC_MODEL_NAME,
        help=f"source T5ForConditionalGeneration model (default: {DEFAULT_SRC_MODEL_NAME})",
    )
    parser.add_argument(
        "dst_dir",
        nargs="?",
        default=DEFAULT_DST_DIR,
        help=f"output directory for the encoder-only model (default: {DEFAULT_DST_DIR})",
    )
    args = parser.parse_args(argv)

    extract_encoder(args.src_model_name, args.dst_dir)
    print(f"Encoder-only model saved to {args.dst_dir}")


if __name__ == "__main__":
    main()
|