# OTA_TextAligner / ddp_train.py
import os
import os.path as osp
import time
import shutil
import logging

import yaml
import click
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger

from utils import *
from optimizers import build_optimizer
from model import *
from meldataset import build_dataloader

# Console logging; accelerate's get_logger() wraps the same underlying
# logging.Logger, so handlers attached here apply to it as well.
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
base_logger = logging.getLogger(__name__)
base_logger.setLevel(logging.DEBUG)
base_logger.addHandler(handler)

logger = get_logger(__name__, log_level="DEBUG")
# torch.autograd.set_detect_anomaly(True)
torch.backends.cudnn.benchmark = True


def log_print(message, logger):
    logger.info(message)
    print(message)

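
# A minimal sketch of the expected Configs/config.yml, inferred from the
# config keys read in main() below; the values shown are illustrative
# assumptions, not the repository's actual defaults:
#
#   log_dir: ./Logs/aligner
#   batch_size: 4
#   epochs: 1000
#   pretrained_model: ''
#   load_only_params: true
#   train_data: ./Data/train_list.txt
#   val_data: ./Data/val_list.txt
#   dataset_params: {}
#   optimizer_params:
#     lr: 5.0e-4
#     pct_start: 0.0
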
@click.command()
@click.option('-p', '--config_path', default='./Configs/config.yml', type=str)
def main(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    ddp_kwargs = DistributedDataParallelKwargs()
    accelerator = Accelerator(project_dir=log_dir, split_batches=True,
                              kwargs_handlers=[ddp_kwargs])

    # TensorBoard and file logging only on the main process
    writer = None
    if accelerator.is_main_process:
        writer = SummaryWriter(log_dir + "/tensorboard")
        file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
        logger.logger.addHandler(file_handler)
    save_iter = 1  # save a checkpoint every `save_iter` epochs
    batch_size = config.get('batch_size', 4)
    log_interval = 10
    device = accelerator.device

    train_path = config.get('train_data', None)
    val_path = config.get('val_data', None)
    epochs = config.get('epochs', 1000)
    train_list, val_list = get_data_path_list(train_path, val_path)
    train_dataloader = build_dataloader(train_list,
                                        batch_size=batch_size,
                                        num_workers=8,
                                        dataset_config=config.get('dataset_params', {}),
                                        device=device)
    val_dataloader = build_dataloader(val_list,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=2,
                                      device=device,
                                      dataset_config=config.get('dataset_params', {}))
    aligner = AlignerModel()
    forward_sum_loss = ForwardSumLoss()
    best_val_loss = float('inf')

    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    optimizer, scheduler = build_optimizer(
        {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    aligner, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
        aligner, optimizer, train_dataloader, val_dataloader, scheduler
    )
    with accelerator.main_process_first():
        if config.get('pretrained_model', '') != '':
            aligner, optimizer, start_epoch, iters = load_checkpoint(
                aligner, optimizer, config['pretrained_model'],
                load_only_params=config.get('load_only_params', True))
        else:
            start_epoch = 0
            iters = 0
    # Training loop (resumes from a loaded checkpoint, if any)
    for epoch in range(start_epoch + 1, epochs + 1):
        aligner.train()
        train_losses = []
        start_time = time.time()

        # Training phase
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
        for i, batch in enumerate(pbar):
            batch = [b.to(device) for b in batch]
            text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

            # Forward pass
            attn_soft, attn_logprob = aligner(spec=mel_input,
                                              spec_len=mel_input_length,
                                              text=text_input,
                                              text_len=text_input_length,
                                              attn_prior=attn_prior)

            # Calculate loss
            loss = forward_sum_loss(attn_logprob=attn_logprob,
                                    in_lens=text_input_length,
                                    out_lens=mel_input_length)

            # Backward pass and optimization
            optimizer.zero_grad()
            accelerator.backward(loss)
            # Optional gradient clipping
            grad_norm = accelerator.clip_grad_norm_(aligner.parameters(), 5.0)
            optimizer.step()
            iters = iters + 1
            if scheduler is not None:
                scheduler.step()

            if (i + 1) % log_interval == 0 and accelerator.is_main_process:
                log_print('Epoch [%d/%d], Step [%d/%d], Forward Sum Loss: %.5f'
                          % (epoch, epochs, i + 1, len(train_dataloader), loss.item()), logger)
                writer.add_scalar('train/Forward Sum Loss', loss.item(), iters)

            train_losses.append(loss.item())

        accelerator.print('Time elapsed:', time.time() - start_time)
        # Calculate average training loss for this epoch
        avg_train_loss = sum(train_losses) / len(train_losses)

        # Validation phase
        aligner.eval()
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                batch = [b.to(device) for b in batch]
                text_input, text_input_length, mel_input, mel_input_length = batch

                # Forward pass (no attention prior at validation time)
                attn_soft, attn_logprob = aligner(spec=mel_input,
                                                  spec_len=mel_input_length,
                                                  text=text_input,
                                                  text_len=text_input_length,
                                                  attn_prior=None)

                # Calculate loss
                val_loss = forward_sum_loss(attn_logprob=attn_logprob,
                                            in_lens=text_input_length,
                                            out_lens=mel_input_length)
                val_losses.append(val_loss.item())

        # Calculate average validation loss
        avg_val_loss = sum(val_losses) / len(val_losses)

        # Log to TensorBoard (main process only; `writer` is None elsewhere)
        if accelerator.is_main_process:
            writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
            writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)
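
        # Hedged addition, not in the original script: `best_val_loss` is
        # initialized above but never used, so this is a minimal sketch of
        # best-model tracking that reuses the periodic checkpoint's layout.
        if accelerator.is_main_process and avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            os.makedirs(osp.join(log_dir, 'checkpoints'), exist_ok=True)
            torch.save({
                'net': accelerator.unwrap_model(aligner).state_dict(),
                'optimizer': optimizer.state_dict(),
                'iters': iters,
                'epoch': epoch,
            }, osp.join(log_dir, 'checkpoints', 'TextAligner_best.pt'))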
        # Save a checkpoint every `save_iter` epochs
        if epoch % save_iter == 0 and accelerator.is_main_process:
            print(f'Saving on iteration {iters}...')
            state = {
                'net': accelerator.unwrap_model(aligner).state_dict(),
                'optimizer': optimizer.state_dict(),
                'iters': iters,
                'epoch': epoch,
            }
            os.makedirs(osp.join(log_dir, 'checkpoints'), exist_ok=True)
            save_path = osp.join(log_dir, 'checkpoints', f'TextAligner_checkpoint_epoch_{epoch}.pt')
            torch.save(state, save_path)
        # Print summary for this epoch
        epoch_time = time.time() - start_time
        accelerator.print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
                          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # # Plot and save attention matrices for visualization
        # if epoch % config.get('plot_every', 10) == 0:
        #     plot_attention_matrices(aligner, val_dataloader, device,
        #                             os.path.join(log_dir, 'attention_plots', f'epoch_{epoch}'),
        #                             num_samples=4)

    if accelerator.is_main_process:
        writer.close()
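
# Typical launch command (standard `accelerate` CLI; set the process count
# to match your GPU count):
#   accelerate launch --num_processes 2 ddp_train.py -p ./Configs/config.yml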
if __name__ == "__main__":
    main()