"""
Simple example usage of CosmicFish model (local model)
"""
import torch
from transformers import GPT2Tokenizer
from modeling_cosmicfish import CosmicFish, CosmicConfig
from safetensors.torch import load_file
import json
def load_cosmicfish(model_dir):
    """Load the CosmicFish model and its tokenizer from a local directory."""
    # Load config
    with open(f"{model_dir}/config.json", "r") as f:
        config_dict = json.load(f)

    # Create model config
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        dropout=0.0,  # no dropout at inference time
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        use_qk_norm=config_dict.get("use_qk_norm", False)
    )

    # Create and load model
    model = CosmicFish(config)
    state_dict = load_file(f"{model_dir}/model.safetensors")

    # Handle weight tying: the LM head reuses the token embedding matrix,
    # so the checkpoint may store only `transformer.wte.weight`.
    if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
        state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']

    model.load_state_dict(state_dict)
    model.eval()

    # Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer
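
# Optional: run on GPU when one is available. A minimal sketch, assuming
# CosmicFish is a standard torch.nn.Module; input ids passed to generate()
# must then live on the same device as the model (e.g. `inputs.to(device)`).
def move_to_device(model, device=None):
    """Move the model to `device`, auto-detecting CUDA when device is None."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device), device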

def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7):
    """Generate text from a prompt"""
    inputs = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=40
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
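
# Note: prompts longer than the model's context window (block_size in
# config.json) may be rejected or silently truncated, depending on how
# CosmicFish.generate() is implemented. A defensive crop of the prompt,
# assuming the config is reachable as `model.config`, would be:
#
#     inputs = inputs[:, -model.config.block_size:]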

if __name__ == "__main__":
    # Load model
    print("Loading CosmicFish...")
    model, tokenizer = load_cosmicfish("./")
    print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)")

    # Example prompts
    prompts = [
        "What is climate change?",
        "Write a poem",
        "Define ML"
    ]

    # Generate responses
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        response = simple_generate(model, tokenizer, prompt, max_tokens=30)
        print(f"Response: {response}")