# greymatter-123M / train.py
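"""Pretraining script for the GreyMatter language model.

Splits a raw text corpus into train/val files, builds dataloaders, and trains the model
with AdamW, gradient accumulation, a linear-warmup + cosine-decay learning-rate schedule,
wandb logging, and periodic checkpointing.
"""
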
import os
import math
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from dataset import get_dataloader
from model import GreyMatter
# Prepare the dataset: split the raw text file (e.g. fineweb_subset.txt) into train.txt and val.txt
def prepare_train_val_split(data_path="fineweb_subset.txt", train_percent=0.9):
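    """Split the raw text file at data_path into train.txt and val.txt by character count.

    The first train_percent of the characters go to train.txt and the remainder to val.txt.
    Returns the paths to the two files.
    """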
    # Skip the split only if both train.txt and val.txt already exist
    path = os.path.dirname(data_path)
    train_path = os.path.join(path, "train.txt")
    val_path = os.path.join(path, "val.txt")
    if not (os.path.exists(train_path) and os.path.exists(val_path)):
        with open(data_path, "r", encoding="utf-8") as file:
            text = file.read()
        train_data = text[:int(len(text) * train_percent)]
        val_data = text[int(len(text) * train_percent):]
        with open(train_path, "w", encoding="utf-8") as file:
            file.write(train_data)
        with open(val_path, "w", encoding="utf-8") as file:
            file.write(val_data)
        print("Train/validation split complete")
    else:
        print("Train and validation files exist, skipping split")
    return train_path, val_path


def calculate_perplexity(loss):
"""Calculate perplexity from cross-entropy loss"""
    return math.exp(loss)


def train(model, train_loader, val_loader, config, warmup_steps=1000):
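    """Train the model with gradient accumulation, gradient clipping, and a
    warmup + cosine learning-rate schedule; validate after every epoch, log to
    wandb, and save a checkpoint every second epoch.

    Returns a dict of per-epoch train/val losses and perplexities.
    """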
    device = config["device"]
    model = model.to(device)
    grad_accum_steps = config.get("grad_acc_step", 1)

    optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # token id 0 is assumed to be padding
    total_steps = len(train_loader) // grad_accum_steps * config["num_epochs"]

    # Cosine LR scheduler with linear warmup
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
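    # The multiplier applied to the base learning rate is:
    #   step / warmup_steps                 while step < warmup_steps (linear warmup)
    #   0.5 * (1 + cos(pi * progress))      afterwards (cosine decay towards 0),
    # where progress runs from 0 to 1 between the end of warmup and total_steps.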

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    train_losses, val_losses = [], []
    train_perplexities, val_perplexities = [], []
    global_step = 0

    print(f"Starting training for {config['num_epochs']} epochs")
    print(f"Total steps: {total_steps}, Warmup steps: {warmup_steps}")
    print("-" * 60)

    for epoch in range(config["num_epochs"]):
        model.train()
        epoch_train_loss = 0
        num_batches = 0
        optimizer.zero_grad()

        for batch_idx, batch_tokens in enumerate(train_loader):
            input_ids = batch_tokens[0].to(device)
            targets = batch_tokens[1].to(device)

            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

            # Gradient accumulation: scale the loss so the accumulated gradients average
            # over grad_accum_steps micro-batches, then step once per accumulation window
            scaled_loss = loss / grad_accum_steps
            scaled_loss.backward()

            if (batch_idx + 1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()  # update LR
                optimizer.zero_grad()
                global_step += 1

            epoch_train_loss += loss.item()
            num_batches += 1

            # Print progress with perplexity
            if batch_idx % 50 == 0:
                current_perplexity = calculate_perplexity(loss.item())
                current_lr = scheduler.get_last_lr()[0]
                print(f"Epoch {epoch+1}/{config['num_epochs']}, Step {batch_idx}, "
                      f"Loss: {loss.item():.4f}, Perplexity: {current_perplexity:.2f}, "
                      f"LR: {current_lr:.6f}")
                wandb.log({
                    "train/epoch": (epoch + 1) / config["num_epochs"],
                    "train/step": batch_idx,
                    "train/loss": float(f"{loss.item():.4f}"),
                    "train/learning_rate": float(f"{current_lr:.6f}"),
                    "train/perplexity": float(f"{current_perplexity:.2f}")
                })

        average_train_loss = epoch_train_loss / num_batches
        train_perplexity = calculate_perplexity(average_train_loss)

        # Validation
        model.eval()
        epoch_val_loss = 0
        val_batches = 0
        print("Running validation...")
        with torch.no_grad():
            for val_tokens in val_loader:
                input_ids = val_tokens[0].to(device)
                targets = val_tokens[1].to(device)
                logits = model(input_ids)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                epoch_val_loss += loss.item()
                val_batches += 1

        average_val_loss = epoch_val_loss / val_batches
        val_perplexity = calculate_perplexity(average_val_loss)

        # Store metrics
        train_losses.append(average_train_loss)
        val_losses.append(average_val_loss)
        train_perplexities.append(train_perplexity)
        val_perplexities.append(val_perplexity)

        # Print epoch summary
        print("-" * 60)
        print(f"EPOCH {epoch+1} SUMMARY:")
        print(f"Train Loss: {average_train_loss:.4f} | Train Perplexity: {train_perplexity:.2f}")
        print(f"Val Loss: {average_val_loss:.4f} | Val Perplexity: {val_perplexity:.2f}")
        print(f"Global Step: {global_step}")
        print("-" * 60)

        wandb.log({
            "val/epoch": epoch + 1,
            "val/loss": float(f"{average_val_loss:.4f}"),
            "val/perplexity": float(f"{val_perplexity:.2f}"),
            "val/global_step": global_step
        })

        # Save a checkpoint every second epoch
        if (epoch + 1) % 2 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch+1}.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': average_train_loss,
                'val_loss': average_val_loss,
                'train_perplexity': train_perplexity,
                'val_perplexity': val_perplexity,
                'global_step': global_step,
                'config': config
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_perplexities': train_perplexities,
        'val_perplexities': val_perplexities
    }
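

# Resuming from a checkpoint (a minimal sketch, assuming the checkpoint dict saved above and
# that the model, optimizer, and scheduler have been rebuilt with the same config first):
#
#   checkpoint = torch.load("checkpoint_epoch_2.pt", map_location=config["device"])
#   model.load_state_dict(checkpoint["model_state_dict"])
#   optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
#   scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
#   start_epoch = checkpoint["epoch"]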
def main():
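    """Configure the run, build the GreyMatter model and dataloaders, train, and
    pickle the resulting metrics."""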
    wandb.login()

    config = {
        'vocab_size': 25000,
        'seq_len': 1024,
        'd_model': 768,
        'n_heads': 8,
        'n_layers': 12,
        'd_ff': 4 * 768,
        'max_seq_len': 1024,
        'dropout': 0.1,
        'batch_size': 8,
        'grad_acc_step': 8,  # effective_batch_size = batch_size * grad_acc_step
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'train_split': 0.8,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }
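    # With these settings each optimizer step accumulates 8 micro-batches of 8 sequences,
    # i.e. 64 sequences of 1024 tokens (about 65k tokens) per update.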

    model = GreyMatter(
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        vocab_size=config["vocab_size"],
        n_layers=config["n_layers"],
        d_ff=config["d_ff"],
        max_seq_len=config["max_seq_len"],
        dropout=config["dropout"]
    )
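    # The repo name advertises roughly 123M parameters for this configuration; the exact
    # count is whatever get_parameter_count() reports below.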
    print(model.get_parameter_count())
    n_params = model.get_parameter_count()[0]
    wandb.init(project=f"greymatter-pretraining-{n_params}", config=config)

    train_path, val_path = prepare_train_val_split(data_path="fineweb_subset.txt", train_percent=config["train_split"])  # replace with your filename

    # Prepare the dataloaders
    train_dataloader = get_dataloader(file_path=train_path, seq_len=config["seq_len"], batch_size=config["batch_size"])
    val_dataloader = get_dataloader(file_path=val_path, seq_len=config["seq_len"], batch_size=config["batch_size"])
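    # Each batch from get_dataloader is assumed to yield an (input_ids, targets) pair of
    # (batch_size, seq_len) token tensors, which is how train() unpacks it above.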

    metrics = train(model, train_dataloader, val_dataloader, config)

    # Save the metrics
    with open('training_metrics.pkl', 'wb') as f:
        pickle.dump(metrics, f)
    print("Training completed! Metrics saved to 'training_metrics.pkl'")


if __name__ == "__main__":
    main()