import os
import gc
import argparse
import datetime
from io import BytesIO
from glob import glob
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2, InterpolationMode

import datasets
import bitsandbytes as bnb
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel

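# Command-line interface: paths, seed, batch size, base learning rate, flow
# shift, conditioning dropout, and step counts are all exposed as flags.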
def parse_args():
    parser = argparse.ArgumentParser(
        description = "Flow-matching UNet training script",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--output_dir",
        type = str,
        default = "./outputs",
        help = "Output directory for training results",
    )
    parser.add_argument(
        "--unet",
        type = str,
        default = "./sd_flow_unet",
        help = "Folder containing the initial UNet weights",
    )
    parser.add_argument(
        "--seed",
        type = int,
        default = 42,
        help = "Seed for reproducible training",
    )
    parser.add_argument(
        "--batch_size",
        type = int,
        default = 16,
        help = "Training batch size",
    )
    parser.add_argument(
        "--base_lr",
        type = float,
        default = 2e-6,
        help = "Base learning rate, will be scaled by sqrt(batch_size)",
    )
    parser.add_argument(
        "--shift",
        type = float,
        default = 2.0,
        help = "Noise schedule shift for training (shift > 1 spends more effort on early/high-noise timesteps)",
    )
    parser.add_argument(
        "--dropout",
        type = float,
        default = 0.1,
        help = "Probability of dropping the text conditioning (to support CFG)",
    )
    parser.add_argument(
        "--max_train_steps",
        type = int,
        default = 50_000,
        help = "Total number of training steps",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type = int,
        default = 1000,
        help = "Save a checkpoint of the training state every X steps",
    )

    args = parser.parse_args()
    return args

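# Training entry point: streams image/caption pairs from local parquet shards,
# encodes them with the frozen SD1.5 VAE and CLIP text encoder, and trains the
# UNet with a rectified-flow (flow matching) objective.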
def train(args):
    device = "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(args.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

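    # CommonCatalog CC-BY parquet shards, streamed so the full dataset never has
    # to fit in memory; the shuffle buffer gives approximate shuffling.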
    data_files = glob("E:/datasets/commoncatalog-cc-by/**/*.parquet", recursive=True)
    train_dataset = datasets.load_dataset("parquet", data_files=data_files, split="train", streaming=True)
    train_dataset = train_dataset.shuffle(seed=args.seed, buffer_size=1000)

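    # Decoded images are converted to float in [0, 1], resized so the short side
    # is 512 px, then center-cropped to 512x512.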
    image_transforms = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Resize(512),
        v2.CenterCrop(512),
    ])

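    # Collate: decode the JPEG bytes, apply the transforms, and rescale pixels
    # from [0, 1] to [-1, 1] (the range the SD VAE expects); NaNs are zeroed and
    # values clamped for safety.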
    def collate_fn(examples):
        captions = []
        pixel_values = []

        for example in examples:
            captions.append(example["blip2_caption"])

            image = Image.open(BytesIO(example["jpg"])).convert('RGB')
            image = image_transforms(image) * 2 - 1
            image = torch.clamp(torch.nan_to_num(image), min=-1, max=1)
            pixel_values.append(image)

        pixel_values = torch.stack(pixel_values, dim=0).contiguous()
        return pixel_values, captions

    train_dataloader = DataLoader(
        dataset = train_dataset,
        batch_size = args.batch_size,
        collate_fn = collate_fn,
        num_workers = 0,
    )

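    # Frozen SD1.5 components: tokenizer, CLIP text encoder, and VAE are kept in
    # bfloat16 and never receive gradients.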
    tokenizer = CLIPTokenizer.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder")
    text_encoder = text_encoder.to(dtype=torch.bfloat16, device=device)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
    vae = vae.to(dtype=torch.bfloat16, device=device)
    vae.requires_grad_(False)
    vae.eval()

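    # The UNet is the only trainable module; gradient checkpointing trades
    # compute for activation memory.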
    unet = UNet2DConditionModel.from_pretrained(args.unet).to(device)
    unet.requires_grad_(True)
    unet.enable_gradient_checkpointing()
    unet.train()

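    # 8-bit AdamW keeps optimizer state small; the learning rate is scaled by
    # sqrt(batch_size) as described in the --base_lr help text.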
    optimizer = bnb.optim.AdamW8bit(
        unet.parameters(),
        lr = args.base_lr * (args.batch_size ** 0.5),
    )

    global_step = 0
    train_logs = {"train_step": [], "train_loss": [], "train_timestep": []}

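    # Tokenize and encode captions; with probability --dropout a caption is
    # replaced by the empty string so the model also learns the unconditional
    # distribution needed for classifier-free guidance.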
    def encode_captions(captions):
        input_ids = []
        for caption in captions:
            if torch.rand(1) < args.dropout:
                caption = ""
            ids = tokenizer(
                caption,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids
            input_ids.append(ids)
        # each `ids` is (1, seq_len); concatenate into a (batch, seq_len) tensor
        input_ids = torch.cat(input_ids, dim=0).to(device)
        return text_encoder(input_ids, return_dict=False)[0].float()

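    # Encode pixels to latents and apply the VAE scaling factor so latents have
    # roughly unit variance.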
    def vae_encode(pixels):
        latents = vae.encode(pixels.to(dtype=torch.bfloat16, device=device)).latent_dist.sample()
        return latents.float() * vae.config.scaling_factor

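    # Flow-matching loss. A noise level sigma in [0, 1) is drawn per sample and
    # remapped with the shift s: sigma' = s*sigma / (1 + (s - 1)*sigma), which
    # biases training toward high-noise timesteps for s > 1. The model sees
    #     x_t = sigma * noise + (1 - sigma) * x_0
    # and is trained to predict the velocity target (noise - x_0) with MSE.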
    def get_pred(batch, log_to=None):
        pixels, captions = batch
        encoder_hidden_states = encode_captions(captions)
        latents = vae_encode(pixels)

        sigmas = torch.rand(latents.shape[0]).to(device)
        sigmas = (args.shift * sigmas) / (1 + (args.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]

        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape))))

        if log_to is not None:
            for i in range(timesteps.shape[0]):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(loss[i].item())
                log_to["train_timestep"].append(timesteps[i].item())

        return loss.mean()

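    # Scatter of per-sample loss vs. timestep, colored by training step, so the
    # loss profile across noise levels can be tracked over time.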
    def plot_logs(log_dict):
        plt.scatter(log_dict["train_timestep"], log_dict["train_loss"], s=3, c=log_dict["train_step"], marker=".", cmap='cool')
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")

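    # Main loop: the outer while restarts the streaming dataloader whenever a
    # pass over the shards is exhausted; training stops after max_train_steps.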
    progress_bar = tqdm(range(0, args.max_train_steps))
    while True:
        for step, batch in enumerate(train_dataloader):
            loss = get_pred(batch, log_to=train_logs)
            t_writer.add_scalar("train/loss", loss.detach().item(), global_step)
            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 2.0)
            t_writer.add_scalar("train/grad_norm", grad_norm.detach().item(), global_step)

            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1

            if global_step % 100 == 0:
                plot_logs(train_logs)
                t_writer.add_figure("train_loss", plt.gcf(), global_step)

            if global_step >= args.max_train_steps or global_step % args.checkpointing_steps == 0:
                checkpoint_path = os.path.join(real_output_dir, f"checkpoint-{global_step:08}")
                unet.save_pretrained(os.path.join(checkpoint_path, "unet"), safe_serialization=True)

            if global_step >= args.max_train_steps:
                break

        # Without this outer check, breaking out of the dataloader loop would
        # fall back into `while True` and training would never terminate.
        if global_step >= args.max_train_steps:
            break

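# Note: a saved checkpoint can be used to warm-start a later run by pointing
# --unet at its "unet" folder, e.g. ./outputs/<timestamp>/checkpoint-00010000/unet.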
if __name__ == "__main__":
    train(parse_args())