# train.py: training script for the deraining Restormer model
######### import ###########
import os
from config import Config
opt = Config('training.yml')  # parse training.yml into a config object
import torch
print("CUDA available:", torch.cuda.is_available())
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # 表示按照PCI_BUS_ID顺序从0开始排列GPU设备。environ是一个字符串所对应环境的映像对象,environ['HOME']就代表了当前这个用户的主目录
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1," #设置当前使用的GPU设备为1,0号两个设备,名称依次为'/gpu:0'、'/gpu:1'。表示优先使用1号设备,然后使用0号设备
torch.backends.cudnn.benchmark = True  # let cuDNN autotune conv algorithms; faster when input sizes are fixed
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import utilss  # project helper package (dir_utils, checkpoint I/O, torchPSNR)
from torch.utils.data import DataLoader
import random
import time
import numpy as np
from data_RGB import get_training_data, get_validation_data
from Restormer import Restormer
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx
######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)  # seed the CPU RNG so results are reproducible
torch.cuda.manual_seed_all(1234)  # seed the RNG on every visible GPU
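# NOTE: these seeds make RNG draws repeatable, but full determinism would also
# require torch.backends.cudnn.deterministic = True (and benchmark = False
# above), at some speed cost.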
start_epoch = 1
mode = opt.MODEL.MODE # deraining
session = opt.MODEL.SESSION # MPRNet
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session) # opt.TRAINING.SAVE_DIR= './checkpoints'
model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
utilss.dir_utils.mkdir(result_dir)
utilss.dir_utils.mkdir(model_dir)
train_dir = opt.TRAINING.TRAIN_DIR # TRAIN_DIR: './Datasets/train'
val_dir = opt.TRAINING.VAL_DIR # VAL_DIR: './Datasets/test/Rain5H'
# factor = 8
######### Model ###########
model_restoration = Restormer()
device_ids = list(range(torch.cuda.device_count()))
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model_restoration = model_restoration.to(device)
if torch.cuda.device_count() > 1:
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
model_restoration.cuda()  # move the model to GPU; model.to(device) is preferred since it names the target device explicitly, especially with multiple GPUs
new_lr = opt.OPTIM.LR_INITIAL # LR_INITIAL: 2e-4
optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8)
######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS - warmup_epochs,
                                                        eta_min=opt.OPTIM.LR_MIN)  # cosine decay over the post-warmup epochs
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                   after_scheduler=scheduler_cosine)  # linear warmup, then hand off to the cosine schedule
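# Illustrative LR trajectory (assuming LR_INITIAL=2e-4, LR_MIN=1e-6,
# NUM_EPOCHS=250, and ildoonet's GradualWarmupScheduler, where multiplier=1
# ramps linearly from 0 up to the base LR over total_epoch steps):
#   epochs 1-3  : LR climbs ~6.7e-5 -> 2e-4
#   epochs 4-250: cosine decay from 2e-4 down to 1e-6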
######### Resume ###########
if opt.TRAINING.RESUME: # RESUME: False
path_chk_rest = utilss.get_last_path(model_dir, '_latest.pth')
utilss.load_checkpoint(model_restoration, path_chk_rest)
start_epoch = utilss.load_start_epoch(path_chk_rest) + 1
utilss.load_optim(optimizer, path_chk_rest)
    for i in range(1, start_epoch):  # fast-forward the scheduler to the resumed epoch
        scheduler.step()
new_lr = scheduler.get_lr()[0]
print('------------------------------------------------------------------------------')
print("==> Resuming Training with learning rate:", new_lr)
print('------------------------------------------------------------------------------')
if len(device_ids) > 1:  # with more than one GPU, run data-parallel training
    print("Multiple GPUs detected; enabling data-parallel training")
    model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)
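# Note: once wrapped in nn.DataParallel, state_dict() keys carry a 'module.'
# prefix; the checkpoints saved below inherit it, and utilss.load_checkpoint is
# presumably written to strip it on resume (an assumption about utilss).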
######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
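# For reference (assumed from the common MPRNet-style definitions in losses.py):
# CharbonnierLoss(x, y) ~ mean(sqrt((x - y)^2 + eps^2)), a smooth L1 surrogate,
# and EdgeLoss applies the same penalty to Laplacian-filtered images.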
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) # TRAIN_PS: 256
# print("train_dataset.shape: ",train_dataset.shape)
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=8,
drop_last=False, pin_memory=True) # BATCH_SIZE: 16
# print("train_loader.shape: ",train_loader.size)
val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS}) # VAL_PS: 128
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=4, drop_last=False,
pin_memory=True)
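# pin_memory=True enables fast async host-to-device copies when paired with
# .cuda(non_blocking=True); the loops below use plain blocking .cuda() calls.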
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS))  # NUM_EPOCHS: 250
print('===> Loading datasets')
best_psnr = 0
best_epoch = 0
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
epoch_start_time = time.time()
epoch_loss = 0
train_id = 1
model_restoration.train()
for i, data in enumerate(tqdm(train_loader), 0):
        # zero gradients; setting .grad = None skips the memset that zero_grad() does
        for param in model_restoration.parameters():
            param.grad = None
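        # Equivalent one-liner on PyTorch >= 1.7 (an alternative, not what this
        # script originally used):
        # optimizer.zero_grad(set_to_none=True)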
target = data[0].cuda()
input_ = data[1].cuda()
# target = data[0].to(device)
# input_ = data[1].to(device)
print("before in model,input_.shape: ",input_.shape)
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
restored = model_restoration(input_)
        # Compute the loss: Charbonnier term plus an edge term weighted by 0.05
        loss_char = criterion_char(restored, target)
        loss_edge = criterion_edge(restored, target)
        loss = loss_char + 0.05 * loss_edge
loss.backward()
optimizer.step()
epoch_loss += loss.item()
#### Evaluation ####
# VAL_AFTER_EVERY: 5
if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
model_restoration.eval()
psnr_val_rgb = []
        for ii, data_val in enumerate(val_loader, 0):
target = data_val[0].cuda()
input_ = data_val[1].cuda()
# target = data_val[0].to(device)
# input_ = data_val[1].to(device)
with torch.no_grad():
restored = model_restoration(input_)
for res, tar in zip(restored, target):
psnr_val_rgb.append(utilss.torchPSNR(res, tar))
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
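        # torchPSNR presumably computes -10 * log10(mean((res - tar)^2)) for
        # images in [0, 1], i.e. 20 * log10(MAX / RMSE) with MAX = 1 (an
        # assumption about utilss.torchPSNR).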
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer': optimizer.state_dict()
}, os.path.join(model_dir, "model_best.pth"))
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer': optimizer.state_dict()
}, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))
scheduler.step()
print("------------------------------------------------------------------")
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.8f}".format(epoch, time.time() - epoch_start_time,
epoch_loss, scheduler.get_lr()[0]))
print("------------------------------------------------------------------")
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer': optimizer.state_dict()
}, os.path.join(model_dir, "model_latest.pth"))
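
# To resume an interrupted run, set TRAINING.RESUME: True in training.yml; the
# Resume block above reloads the latest weights and optimizer state and
# fast-forwards the scheduler to the saved epoch.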