"""Simple example of using the CosmicFish model from a local checkpoint."""

import json

import torch
from transformers import GPT2Tokenizer
from safetensors.torch import load_file

from modeling_cosmicfish import CosmicFish, CosmicConfig


def load_cosmicfish(model_dir):
    """Load CosmicFish model and tokenizer."""

    with open(f"{model_dir}/config.json", "r") as f:
        config_dict = json.load(f)

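    # Rebuild the config from the architecture flags saved with the checkpoint
    # (rotary embeddings, SwiGLU MLP, grouped-query attention, optional QK-norm).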
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        dropout=0.0,
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        use_qk_norm=config_dict.get("use_qk_norm", False),
    )

    model = CosmicFish(config)
    state_dict = load_file(f"{model_dir}/model.safetensors")

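    # The output head is presumably tied to the token embeddings, and safetensors
    # stores shared tensors only once, so restore lm_head from wte if it is absent.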
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    model.load_state_dict(state_dict)
    model.eval()

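    # CosmicFish appears to reuse the GPT-2 vocabulary, so the stock GPT-2
    # tokenizer is loaded from the Hub; only the model weights are local.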
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return model, tokenizer


def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7):
    """Generate text from a prompt."""
    inputs = tokenizer.encode(prompt, return_tensors="pt")

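    # Sample without tracking gradients; a moderate temperature plus top-k
    # filtering keeps short completions reasonably coherent.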
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=40,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


if __name__ == "__main__":
    print("Loading CosmicFish...")
    model, tokenizer = load_cosmicfish("./")
    print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)")

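    # A few short smoke-test prompts; each gets a 30-token completion below.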
    prompts = [
        "What is climate change?",
        "Write a poem",
        "Define ML",
    ]

    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        response = simple_generate(model, tokenizer, prompt, max_tokens=30)
        print(f"Response: {response}")