"""
Simple example usage of CosmicFish model (local model)
"""

import json
import os

import torch
from safetensors.torch import load_file
from transformers import GPT2Tokenizer

from modeling_cosmicfish import CosmicFish, CosmicConfig


def load_cosmicfish(model_dir):
    """Load CosmicFish model and tokenizer from a local directory.

    Args:
        model_dir: Directory containing ``config.json`` and
            ``model.safetensors``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has its weights loaded and
        is switched to eval mode; the tokenizer is the stock ``"gpt2"``
        ``GPT2Tokenizer``.

    Raises:
        FileNotFoundError: if either file is missing from ``model_dir``.
        KeyError: if a required key is missing from ``config.json``.
    """
    # Load config (explicit UTF-8 so behavior does not depend on the locale).
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Create model config
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        # Inference only: dropout is forced off regardless of the saved config.
        dropout=0.0,
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        # Optional key: defaults to False when absent from config.json.
        use_qk_norm=config_dict.get("use_qk_norm", False),
    )

    # Create and load model
    model = CosmicFish(config)
    state_dict = load_file(os.path.join(model_dir, "model.safetensors"))

    # Handle weight sharing: a checkpoint with tied embeddings may store only
    # the token-embedding matrix. Reuse it for the LM head so that the
    # (strict, by default) load_state_dict call finds every expected key.
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    model.load_state_dict(state_dict)
    model.eval()

    # Load tokenizer (the standard GPT-2 BPE tokenizer).
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer


def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7, top_k=40):
    """Generate text from a prompt.

    Args:
        model: A loaded CosmicFish model (see ``load_cosmicfish``).
        tokenizer: The tokenizer matching ``model``.
        prompt: Input text to continue.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.

    Returns:
        The decoded text of the first sequence returned by ``model.generate``,
        with special tokens removed.
        NOTE(review): this presumably includes the prompt itself followed by
        the continuation; confirm against ``CosmicFish.generate``.
    """
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def main():
    """Load the model from the current directory and run a few demo prompts."""
    # Load model
    print("Loading CosmicFish...")
    model, tokenizer = load_cosmicfish("./")
    print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)")

    # Example prompts
    prompts = [
        "What is climate change?",
        "Write a poem",
        "Define ML",
    ]

    # Generate responses
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        response = simple_generate(model, tokenizer, prompt, max_tokens=30)
        print(f"Response: {response}")


if __name__ == "__main__":
    main()