#!/usr/bin/env python
"""Extract the encoder-only part of the full google/t5-v1_1-xl checkpoint.

Loads the full encoder-decoder model, copies every weight except the
decoder stack and the LM head into a freshly built ``T5EncoderModel`` that
uses the same config, and saves the result with ``save_pretrained``.

The tokenizer is NOT written to the output directory; load it from
``src_model_name`` (the config is unchanged, so the same tokenizer applies).
"""

src_model_name = "google/t5-v1_1-xl"
dst_dir = "./t5-v1_1-xl-encoder-only"

# State-dict key prefixes belonging to the decoder stack and the LM head,
# i.e. the parameters T5EncoderModel does not have.
_DROPPED_PREFIXES = ("decoder.", "lm_head.")


def encoder_only_state_dict(state_dict: dict) -> dict:
    """Return a new dict with the decoder and LM-head entries removed.

    Keeps every entry of *state_dict* whose key does not start with
    ``"decoder."`` or ``"lm_head."``. The input mapping is not modified and
    the values (tensors) are shared, not copied.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if not key.startswith(_DROPPED_PREFIXES)
    }


def main(src: str = src_model_name, dst: str = dst_dir) -> None:
    """Build an encoder-only model from checkpoint *src* and save it to *dst*."""
    # Imported here rather than at module level so the pure filtering helper
    # above can be imported (e.g. by tests) without transformers installed.
    from transformers import T5EncoderModel, T5ForConditionalGeneration

    full_model = T5ForConditionalGeneration.from_pretrained(src)
    # Fresh encoder-only model built from the same config; its initial weights
    # are replaced by the load below.
    encoder_model = T5EncoderModel(full_model.config)
    # load_state_dict is strict by default: it raises if any encoder key is
    # missing or an unexpected key survived the filter, so a partial copy
    # cannot pass silently.
    encoder_model.load_state_dict(encoder_only_state_dict(full_model.state_dict()))
    # The encoder now holds its own copies of the weights; drop the full model
    # so its memory can be reclaimed before the checkpoint is written.
    del full_model

    encoder_model.save_pretrained(dst)
    print(f"Encoder-only model saved to {dst}")


if __name__ == "__main__":
    main()