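"""Pretraining script for the GreyMatter language model.

Splits the raw text corpus into train/val files, trains with gradient
accumulation and a warmup + cosine learning-rate schedule, runs per-epoch
validation, logs metrics to Weights & Biases, and saves periodic checkpoints.
"""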
import math
import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim

import wandb

from dataset import get_dataloader
from model import GreyMatter


def prepare_train_val_split(data_path="fineweb_subset.txt", train_percent=0.9):
    """Split the raw corpus into train.txt and val.txt, skipping if both already exist."""
    path = os.path.dirname(data_path)
    train_path = os.path.join(path, "train.txt")
    val_path = os.path.join(path, "val.txt")

    # Regenerate the split if either file is missing
    if not os.path.exists(train_path) or not os.path.exists(val_path):
        with open(data_path, "r", encoding="utf-8") as file:
            text = file.read()

        split_idx = int(len(text) * train_percent)
        train_data = text[:split_idx]
        val_data = text[split_idx:]

        with open(train_path, "w", encoding="utf-8") as file:
            file.write(train_data)

        with open(val_path, "w", encoding="utf-8") as file:
            file.write(val_data)
        print("Train/validation split complete")
    else:
        print("Train and validation files exist, skipping split")

    return train_path, val_path


def calculate_perplexity(loss):
    """Calculate perplexity from a cross-entropy loss value."""
    return math.exp(loss)


def train(model, train_loader, val_loader, config, warmup_steps=1000):
    device = config["device"]
    model = model.to(device)

    grad_accum_steps = config.get("grad_acc_step", 1)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    # One optimizer step per gradient-accumulation cycle
    total_steps = len(train_loader) // grad_accum_steps * config["num_epochs"]

    def lr_lambda(step):
        # Linear warmup followed by cosine decay to zero
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    train_losses, val_losses = [], []
    train_perplexities, val_perplexities = [], []
    global_step = 0

    print(f"Starting training for {config['num_epochs']} epochs")
    print(f"Total steps: {total_steps}, Warmup steps: {warmup_steps}")
    print("-" * 60)

    for epoch in range(config["num_epochs"]):
        model.train()
        epoch_train_loss = 0
        num_batches = 0
        optimizer.zero_grad()

        for batch_idx, batch_tokens in enumerate(train_loader):
            input_ids = batch_tokens[0].to(device)
            targets = batch_tokens[1].to(device)
            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

            # Scale the loss so gradients average over the accumulation window
            scaled_loss = loss / grad_accum_steps
            scaled_loss.backward()

            # Step the optimizer only once per accumulation cycle
            if (batch_idx + 1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            epoch_train_loss += loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                current_perplexity = calculate_perplexity(loss.item())
                current_lr = scheduler.get_last_lr()[0]
                print(f"Epoch {epoch+1}/{config['num_epochs']}, Step {batch_idx}, "
                      f"Loss: {loss.item():.4f}, Perplexity: {current_perplexity:.2f}, "
                      f"LR: {current_lr:.6f}")
                wandb.log({
                    "train/epoch": (epoch + 1) / config["num_epochs"],
                    "train/step": batch_idx,
                    "train/loss": loss.item(),
                    "train/learning_rate": current_lr,
                    "train/perplexity": current_perplexity,
                })

        average_train_loss = epoch_train_loss / num_batches
        train_perplexity = calculate_perplexity(average_train_loss)

        # Validation pass
        model.eval()
        epoch_val_loss = 0
        val_batches = 0

        print("Running validation...")
        with torch.no_grad():
            for val_tokens in val_loader:
                input_ids = val_tokens[0].to(device)
                targets = val_tokens[1].to(device)
                logits = model(input_ids)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                epoch_val_loss += loss.item()
                val_batches += 1

        average_val_loss = epoch_val_loss / val_batches
        val_perplexity = calculate_perplexity(average_val_loss)

        train_losses.append(average_train_loss)
        val_losses.append(average_val_loss)
        train_perplexities.append(train_perplexity)
        val_perplexities.append(val_perplexity)

        print("-" * 60)
        print(f"EPOCH {epoch+1} SUMMARY:")
        print(f"Train Loss: {average_train_loss:.4f} | Train Perplexity: {train_perplexity:.2f}")
        print(f"Val Loss: {average_val_loss:.4f} | Val Perplexity: {val_perplexity:.2f}")
        print(f"Global Step: {global_step}")
        print("-" * 60)

        wandb.log({
            "val/epoch": epoch + 1,
            "val/loss": average_val_loss,
            "val/perplexity": val_perplexity,
            "val/global_step": global_step,
        })

        # Save a checkpoint every second epoch
        if (epoch + 1) % 2 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch+1}.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': average_train_loss,
                'val_loss': average_val_loss,
                'train_perplexity': train_perplexity,
                'val_perplexity': val_perplexity,
                'global_step': global_step,
                'config': config
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_perplexities': train_perplexities,
        'val_perplexities': val_perplexities
    }


def main():
    wandb.login()
    config = {
        'vocab_size': 25000,
        'seq_len': 1024,
        'd_model': 768,
        'n_heads': 8,
        'n_layers': 12,
        'd_ff': 4 * 768,
        'max_seq_len': 1024,
        'dropout': 0.1,
        'batch_size': 8,
        'grad_acc_step': 8,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'train_split': 0.8,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    model = GreyMatter(
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        vocab_size=config["vocab_size"],
        n_layers=config["n_layers"],
        d_ff=config["d_ff"],
        max_seq_len=config["max_seq_len"],
        dropout=config["dropout"]
    )

    param_counts = model.get_parameter_count()
    print(param_counts)

    n_params = param_counts[0]
    wandb.init(project=f"greymatter-pretraining-{n_params}", config=config)

    train_path, val_path = prepare_train_val_split(
        data_path="fineweb_subset.txt", train_percent=config["train_split"]
    )

    train_dataloader = get_dataloader(file_path=train_path, seq_len=config["seq_len"], batch_size=config["batch_size"])
    val_dataloader = get_dataloader(file_path=val_path, seq_len=config["seq_len"], batch_size=config["batch_size"])

    metrics = train(model, train_dataloader, val_dataloader, config)

    with open('training_metrics.pkl', 'wb') as f:
        pickle.dump(metrics, f)

    print("Training completed! Metrics saved to 'training_metrics.pkl'")


if __name__ == "__main__":
    main()