#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-05-18 13:04:06
# (Web-page banner captured with this file, not part of the program: "Spaces: Running on T4")
import contextlib
import datetime
import functools
import math
import os
import random
import sys
import time
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path

import lpips
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as udata
import torchvision.utils as vutils
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.utils import DiffJPEG
from basicsr.utils.img_process_util import filter2D
from einops import rearrange
from loguru import logger
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from datapipe.datasets import create_dataset
from models.resample import UniformSampler
from utils import util_common
from utils import util_image
from utils import util_net
class TrainerBase:
    """Skeleton of a single- or multi-GPU training run.

    The constructor performs, in order: distributed setup, seeding, logger
    setup, model construction, optimizer construction and checkpoint
    resuming. Subclasses are expected to override ``prepare_data``,
    ``training_step`` and ``validation``.
    """

    def __init__(self, configs):
        """Store ``configs`` and run every setup stage.

        Args:
            configs: configuration object; it is dumped with
                ``OmegaConf.to_yaml`` and read via attribute access.
        """
        self.configs = configs

        # setup distributed training: self.num_gpus, self.rank
        self.setup_dist()

        # setup seed
        self.setup_seed()

        # setup logger: self.logger (rank 0 only)
        self.init_logger()

        # logging the configurations
        if self.rank == 0:
            self.logger.info(OmegaConf.to_yaml(self.configs))

        # build model: self.model
        self.build_model()

        # setup optimization: self.optimizer (subclasses may add self.lr_sheduler)
        self.setup_optimizaton()

        # resume: self.iters_start
        self.resume_from_ckpt()

    def setup_dist(self):
        """Select the visible GPUs and, with more than one, join the process group.

        Sets ``self.num_gpus`` and ``self.rank``. The rank is read from the
        ``LOCAL_RANK`` environment variable when several GPUs are used and is
        0 otherwise.
        """
        if self.configs.gpu_id:
            gpu_id = self.configs.gpu_id
            num_gpus = len(gpu_id)
            os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
            # str() so that ids given as integers are accepted as well as strings
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in gpu_id)
        else:
            num_gpus = torch.cuda.device_count()

        rank = 0
        if num_gpus > 1:
            if mp.get_start_method(allow_none=True) is None:
                mp.set_start_method('spawn')
            rank = int(os.environ['LOCAL_RANK'])
            torch.cuda.set_device(rank % num_gpus)
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
            )

        self.num_gpus = num_gpus
        self.rank = rank

    def setup_seed(self, seed=None):
        """Seed ``random``, NumPy and PyTorch (CPU and all CUDA devices).

        Args:
            seed: seed to use; ``self.configs.seed`` when ``None``.
        """
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def init_logger(self):
        """Create the output folders, the text logger and the tensorboard writer.

        ``self.ckpt_dir`` is set on every rank. Everything else
        (``self.logger``, ``self.writer``, ``self.log_step``,
        ``self.log_step_img`` and the folders themselves) is created on
        rank 0 only.
        """
        save_dir = Path(self.configs.save_dir)
        logtxet_path = save_dir / 'training.log'
        log_dir = save_dir / 'logs'
        ckpt_dir = save_dir / 'ckpts'
        self.ckpt_dir = ckpt_dir
        if self.rank != 0:
            return

        if not save_dir.exists():
            save_dir.mkdir()
        else:
            # an existing folder is only acceptable when resuming
            assert self.configs.resume, (
                'Please check the resume parameter. If you do not\n'
                'want to resume from some checkpoint, please delete\n'
                'the saving folder first.'
            )

        # text logging
        if logtxet_path.exists():
            assert self.configs.resume
        self.logger = logger
        self.logger.remove()
        self.logger.add(logtxet_path, format="{message}", mode='a')
        self.logger.add(sys.stderr, format="{message}")

        # tensorboard log
        if not log_dir.exists():
            log_dir.mkdir()
        self.writer = SummaryWriter(str(log_dir))
        self.log_step = {phase: 1 for phase in ['train', 'val']}
        self.log_step_img = {phase: 1 for phase in ['train', 'val']}

        if not ckpt_dir.exists():
            ckpt_dir.mkdir()

    def close_logger(self):
        """Close the tensorboard writer (rank 0 only)."""
        if self.rank == 0:
            self.writer.close()

    def resume_from_ckpt(self):
        """Restore model weights and training progress, or start from iteration 0.

        ``configs.resume`` may be falsy (fresh run), ``True`` (use the
        checkpoint with the largest iteration index in ``self.ckpt_dir``) or
        a checkpoint path. Sets ``self.iters_start``.

        Raises:
            FileNotFoundError: ``resume`` is ``True`` but the checkpoint
                folder holds no ``*.pth`` file.
        """
        if not self.configs.resume:
            self.iters_start = 0
            return

        if isinstance(self.configs.resume, bool):
            # file names follow "model_<iters>.pth" (see save_ckpt)
            indices = [int(x.stem.split('_')[1]) for x in Path(self.ckpt_dir).glob('*.pth')]
            if not indices:
                raise FileNotFoundError(f"No checkpoint found in {self.ckpt_dir}")
            ckpt_path = str(Path(self.ckpt_dir) / f"model_{max(indices)}.pth")
        else:
            ckpt_path = self.configs.resume
        assert os.path.isfile(ckpt_path)
        if self.rank == 0:
            self.logger.info(f"=> Loaded checkpoint {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
        util_net.reload_model(self.model, ckpt['state_dict'])
        torch.cuda.empty_cache()

        # iterations
        self.iters_start = ckpt['iters_start']

        # learning rate scheduler: replay the schedule up to the resume point
        for ii in range(self.iters_start):
            self.adjust_lr(ii)

        if self.rank == 0:
            self.log_step = ckpt['log_step']
            self.log_step_img = ckpt['log_step_img']

        # reset the seed
        self.setup_seed(self.iters_start)

    def setup_optimizaton(self):
        """Create ``self.optimizer`` (AdamW over all model parameters)."""
        self.optimizer = torch.optim.AdamW(self.model.parameters(),
                                           lr=self.configs.train.lr,
                                           weight_decay=self.configs.train.weight_decay)

    def build_model(self):
        """Instantiate ``configs.model.target`` and move it to the GPU.

        With several GPUs the model is wrapped in ``DistributedDataParallel``.
        Rank 0 additionally creates ``self.lpips_loss``.
        """
        # an empty dict (not the ``dict`` type) so that ``**params`` works
        # when the config has no "params" entry
        params = self.configs.model.get('params', {})
        model = util_common.get_obj_from_str(self.configs.model.target)(**params)
        if self.num_gpus > 1:
            self.model = DDP(model.cuda(), device_ids=[self.rank,])  # wrap the network
        else:
            self.model = model.cuda()

        # LPIPS metric
        if self.rank == 0:
            self.lpips_loss = lpips.LPIPS(net='vgg').cuda()

        # model information
        self.print_model_info()

    def build_dataloader(self):
        """Build the training dataset and an endless training data iterator.

        Sets ``self.datasets``, ``self.dataloaders`` and ``self.sampler``
        (``None`` unless several GPUs are used).
        """
        def _wrap_loader(loader):
            # restart the loader whenever it is exhausted
            while True:
                yield from loader

        datasets = {}
        for phase in ['train', ]:
            dataset_config = self.configs.data[phase]
            datasets[phase] = create_dataset(dataset_config)

        dataloaders = {}
        # train dataloader
        if self.rank == 0:
            for phase in ['train', ]:
                length = len(datasets[phase])
                self.logger.info('Number of images in {:s} data set: {:d}'.format(phase, length))
        if self.num_gpus > 1:
            shuffle = False
            sampler = udata.distributed.DistributedSampler(datasets['train'],
                                                           num_replicas=self.num_gpus,
                                                           rank=self.rank)
        else:
            shuffle = True
            sampler = None

        num_workers = self.configs.train.num_workers // self.num_gpus
        loader_kwargs = dict(
            batch_size=self.configs.train.batch[0] // self.num_gpus,
            shuffle=shuffle,
            drop_last=False,
            num_workers=num_workers,
            pin_memory=True,
            worker_init_fn=my_worker_init_fn,
            sampler=sampler,
        )
        if num_workers > 0:
            # DataLoader rejects an explicit prefetch_factor without worker processes
            loader_kwargs['prefetch_factor'] = self.configs.train.prefetch_factor
        dataloaders['train'] = _wrap_loader(udata.DataLoader(datasets['train'], **loader_kwargs))

        self.datasets = datasets
        self.dataloaders = dataloaders
        self.sampler = sampler

    def print_model_info(self):
        """Log the network architecture and its parameter count (rank 0 only)."""
        if self.rank == 0:
            num_params = util_net.calculate_parameters(self.model) / 1000**2
            self.logger.info("Detailed network architecture:")
            self.logger.info(repr(self.model))
            self.logger.info(f"Number of parameters: {num_params:.2f}M")

    def prepare_data(self, phase='train'):
        """Placeholder; subclasses override it.

        NOTE(review): ``train`` calls this with ``(batch, bool)``, which this
        placeholder signature does not accept — a subclass override is required.
        """
        pass

    def validation(self):
        """Placeholder; subclasses override it."""
        pass

    def train(self):
        """Run the training loop from ``self.iters_start`` to ``configs.train.iterations``."""
        self.build_dataloader()  # prepare data: self.dataloaders, self.datasets, self.sampler

        self.model.train()
        num_iters_epoch = math.ceil(len(self.datasets['train']) / self.configs.train.batch[0])
        for ii in range(self.iters_start, self.configs.train.iterations):
            self.current_iters = ii + 1

            # prepare data
            data = self.prepare_data(
                next(self.dataloaders['train']),
                self.configs.data.train.type.lower() == 'realesrgan',
            )

            # training phase
            self.training_step(data)

            # validation phase
            if (ii+1) % self.configs.train.val_freq == 0 and 'val' in self.dataloaders:
                if self.rank == 0:
                    self.validation()

            # update learning rate; the index matches the one replayed in
            # resume_from_ckpt so a resumed run follows the same schedule
            self.adjust_lr(ii)

            # save checkpoint
            if (ii+1) % self.configs.train.save_freq == 0 and self.rank == 0:
                self.save_ckpt()

            if (ii+1) % num_iters_epoch == 0 and self.sampler is not None:
                self.sampler.set_epoch(ii+1)

        # close the tensorboard
        if self.rank == 0:
            self.close_logger()

    def training_step(self, data):
        """Placeholder; subclasses override it."""
        pass

    def adjust_lr(self, ii=None):
        """Advance ``self.lr_sheduler`` by one step, if one has been set up.

        Args:
            ii: iteration index. Unused here; accepted because
                ``resume_from_ckpt`` and ``train`` pass it.
        """
        if hasattr(self, 'lr_sheduler'):
            self.lr_sheduler.step()

    def save_ckpt(self):
        """Save model weights and logging counters to ``model_<iters>.pth``."""
        ckpt_path = self.ckpt_dir / 'model_{:d}.pth'.format(self.current_iters)
        torch.save({'iters_start': self.current_iters,
                    'log_step': {phase: self.log_step[phase] for phase in ['train', 'val']},
                    'log_step_img': {phase: self.log_step_img[phase] for phase in ['train', 'val']},
                    'state_dict': self.model.state_dict()}, ckpt_path)
class TrainerSR(TrainerBase):
    """Trainer for a restoration network trained on (``lq``, ``gt``) image pairs.

    Training pairs are either synthesised on the fly from ``gt`` images by a
    two-stage blur / resize / noise / JPEG degradation, or taken as provided
    by the dataset.

    NOTE(review): ``training_step`` calls ``self.loss_fun``, which is not
    defined in this class or in ``TrainerBase`` as shown here; it has to be
    supplied elsewhere (e.g. by a subclass) — confirm.
    """

    def __init__(self, configs):
        super().__init__(configs)

    def mse_loss(self, pred, target):
        """Return the mean squared error between ``pred`` and ``target``."""
        return F.mse_loss(pred, target, reduction='mean')

    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        Reads and (once the pool is full) replaces ``self.lq`` / ``self.gt``.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_size'):
            self.queue_size = self.configs.data.train.params.get('queue_size', b*50)
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0

        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples (cloned because the rows are overwritten below)
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq
            self.queue_gt[0:b, :, :, :] = self.gt

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue; a batch whose size differs from the first one
            # (the loader keeps the last, partial batch) may not fit into the
            # remaining space, so store only as many samples as still fit
            num_fit = min(b, self.queue_size - self.queue_ptr)
            stop = self.queue_ptr + num_fit
            self.queue_lr[self.queue_ptr:stop, :, :, :] = self.lq[:num_fit]
            self.queue_gt[self.queue_ptr:stop, :, :, :] = self.gt[:num_fit]
            self.queue_ptr = stop

    def _random_resize_scale(self, prob_key, range_key):
        """Draw a resize factor: up, down or keep, weighted by ``degradation[prob_key]``."""
        updown_type = random.choices(
            ['up', 'down', 'keep'],
            self.configs.degradation[prob_key],
        )[0]
        if updown_type == 'up':
            return random.uniform(1, self.configs.degradation[range_key][1])
        if updown_type == 'down':
            return random.uniform(self.configs.degradation[range_key][0], 1)
        return 1

    def _add_random_noise(self, out, suffix=''):
        """Add Gaussian or Poisson noise; ``suffix`` selects the config keys ('' or '2')."""
        opts = self.configs.degradation
        gray_noise_prob = opts['gray_noise_prob' + suffix]
        if random.random() < opts['gaussian_noise_prob' + suffix]:
            return random_add_gaussian_noise_pt(
                out,
                sigma_range=opts['noise_range' + suffix],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        return random_add_poisson_noise_pt(
            out,
            scale_range=opts['poisson_scale_range' + suffix],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False,
        )

    def _jpeg_compress(self, out, range_key):
        """JPEG-compress ``out`` with a per-sample quality drawn from ``degradation[range_key]``."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation[range_key])
        out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        return self.jpeger(out, quality=jpeg_p)

    def _resize_back_and_sinc(self, out, size, sinc_kernel):
        """Resize ``out`` to ``size`` with a random mode, then apply the sinc filter."""
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=size, mode=mode)
        return filter2D(out, sinc_kernel)

    def prepare_data(self, data, real_esrgan=True):
        """Move a batch to the GPU and, if requested, synthesise the degraded input.

        Args:
            data: batch dict. With ``real_esrgan`` it must provide ``gt``,
                ``kernel1``, ``kernel2`` and ``sinc_kernel``.
            real_esrgan: when true, build ``lq`` from ``gt`` with the
                two-stage degradation and pass the pair through the training
                pool; otherwise just move every entry to the GPU.

        Returns:
            dict with ``lq`` and ``gt`` (``real_esrgan``), or the input dict
            with every value on the GPU.
        """
        if not real_esrgan:
            return {key: value.cuda() for key, value in data.items()}

        if not hasattr(self, 'jpeger'):
            self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts

        im_gt = data['gt'].cuda()
        kernel1 = data['kernel1'].cuda()
        kernel2 = data['kernel2'].cuda()
        sinc_kernel = data['sinc_kernel'].cuda()

        ori_h, ori_w = im_gt.size()[2:4]
        sf = self.configs.model.params.sf
        lq_size = (ori_h // sf, ori_w // sf)

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(im_gt, kernel1)
        # random resize
        scale = self._random_resize_scale('resize_prob', 'resize_range')
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # add noise
        out = self._add_random_noise(out)
        # JPEG compression
        out = self._jpeg_compress(out, 'jpeg_range')

        # ----------------------- The second degradation process ----------------------- #
        # blur
        if random.random() < self.configs.degradation['second_blur_prob']:
            out = filter2D(out, kernel2)
        # random resize
        scale = self._random_resize_scale('resize_prob2', 'resize_range2')
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out,
            size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
            mode=mode,
        )
        # add noise
        out = self._add_random_noise(out, suffix='2')

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
        if random.random() < 0.5:
            out = self._resize_back_and_sinc(out, lq_size, sinc_kernel)
            out = self._jpeg_compress(out, 'jpeg_range2')
        else:
            out = self._jpeg_compress(out, 'jpeg_range2')
            out = self._resize_back_and_sinc(out, lq_size, sinc_kernel)

        # clamp and round
        im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        # random crop
        gt_size = self.configs.degradation['gt_size']
        im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, sf)
        self.lq, self.gt = im_lq, im_gt

        # training pair pool
        self._dequeue_and_enqueue()
        self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract

        return {'lq': self.lq, 'gt': self.gt}

    def setup_optimizaton(self):
        """Create the optimizer and a cosine-annealing schedule over all iterations."""
        super().setup_optimizaton()  # self.optimizer
        self.lr_sheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.configs.train.iterations,
            eta_min=self.configs.train.lr_min,
        )

    def training_step(self, data):
        """Run one optimizer step, accumulating gradients over micro-batches.

        Args:
            data: dict with ``lq`` and ``gt`` tensors for the whole batch.
        """
        current_batchsize = data['lq'].shape[0]
        micro_batchsize = self.configs.train.microbatch
        num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)

        self.optimizer.zero_grad()
        for jj in range(0, current_batchsize, micro_batchsize):
            micro_data = {key: value[jj:jj+micro_batchsize,] for key, value in data.items()}
            last_batch = (jj+micro_batchsize >= current_batchsize)
            # Gradients only need to be synchronised on the last micro-batch.
            # Both the forward and the backward pass must run inside
            # no_sync() for the synchronisation to be skipped.
            if last_batch or self.num_gpus <= 1:
                grad_sync = contextlib.nullcontext()
            else:
                grad_sync = self.model.no_sync()
            with grad_sync:
                hq_pred = self.model(micro_data['lq'])
                loss = self.loss_fun(hq_pred, micro_data['gt']) / hq_pred.shape[0]
                loss = loss / num_grad_accumulate
                loss.backward()

            # make logging
            self.log_step_train(hq_pred, loss, micro_data, flag=last_batch)

        self.optimizer.step()

    def log_step_train(self, hq_pred, loss, batch, flag=False, phase='train'):
        '''
        Accumulate the training loss and periodically log it (rank 0 only).

        param hq_pred: network output for the micro-batch
        param loss: loss value
        param batch: micro-batch dict with 'lq' and 'gt'
        param flag: True on the last micro-batch of an iteration
        '''
        if self.rank != 0:
            return

        log_freq = self.configs.train.log_freq
        save_freq = self.configs.train.save_freq

        # Created lazily and reset after each report, so it also exists after
        # resuming in the middle of a logging window and every micro-batch
        # of the window is counted.
        if not hasattr(self, 'loss_mean'):
            self.loss_mean = 0
        self.loss_mean += loss.item()

        if self.current_iters % log_freq[0] == 0 and flag:
            self.loss_mean /= log_freq[0]
            mse_pixel = self.loss_mean / batch['gt'].numel() * batch['gt'].shape[0]
            log_str = 'Train:{:05d}/{:05d}, Loss:{:.2e}, MSE:{:.2e}, lr:{:.2e}'.format(
                self.current_iters // 100,
                self.configs.train.iterations // 100,
                self.loss_mean,
                mse_pixel,
                self.optimizer.param_groups[0]['lr']
            )
            self.logger.info(log_str)
            # tensorboard
            self.writer.add_scalar('Loss-Train', self.loss_mean, self.log_step[phase])
            self.log_step[phase] += 1
            self.loss_mean = 0
        if self.current_iters % log_freq[1] == 0 and flag:
            x1 = vutils.make_grid(batch['lq'], normalize=True, scale_each=True)
            self.writer.add_image("Train LQ Image", x1, self.log_step_img[phase])
            x2 = vutils.make_grid(batch['gt'], normalize=True, scale_each=True)
            self.writer.add_image("Train HQ Image", x2, self.log_step_img[phase])
            x3 = vutils.make_grid(hq_pred.detach().data, normalize=True, scale_each=True)
            self.writer.add_image("Train Recovered Image", x3, self.log_step_img[phase])
            self.log_step_img[phase] += 1

        if self.current_iters % save_freq == 1 and flag:
            self.tic = time.time()
        if self.current_iters % save_freq == 0 and flag:
            self.toc = time.time()
            # self.tic is missing when no "first iteration of a window" was seen
            elaplsed = self.toc - getattr(self, 'tic', self.toc)
            self.logger.info(f"Elapsed time: {elaplsed:.2f}s")
            self.logger.info("="*60)

    def validation(self, phase='val'):
        """Evaluate PSNR and LPIPS on the validation set and log them (rank 0 only)."""
        if self.rank != 0:
            return

        self.model.eval()

        psnr_mean = lpips_mean = 0
        total_iters = math.ceil(len(self.datasets[phase]) / self.configs.train.batch[1])
        for ii, data in enumerate(self.dataloaders[phase]):
            # validation batches are used as provided: no synthetic
            # degradation and no mixing with the training pair pool
            data = self.prepare_data(data, real_esrgan=False)
            with torch.no_grad():
                hq_pred = self.model(data['lq'])
                hq_pred.clamp_(0.0, 1.0)
                lpips_value = self.lpips_loss(
                    util_image.normalize_th(hq_pred, reverse=False),
                    util_image.normalize_th(data['gt'], reverse=False),
                ).sum().item()
            psnr = util_image.batch_PSNR(
                hq_pred,
                data['gt'],
                ycbcr=True
            )

            psnr_mean += psnr
            lpips_mean += lpips_value

            if (ii+1) % self.configs.train.log_freq[2] == 0:
                log_str = '{:s}:{:03d}/{:03d}, PSNR={:5.2f}, LPIPS={:6.4f}'.format(
                    phase,
                    ii+1,
                    total_iters,
                    psnr / hq_pred.shape[0],
                    lpips_value / hq_pred.shape[0]
                )
                self.logger.info(log_str)
                x1 = vutils.make_grid(data['lq'], normalize=True, scale_each=True)
                self.writer.add_image("Validation LQ Image", x1, self.log_step_img[phase])
                x2 = vutils.make_grid(data['gt'], normalize=True, scale_each=True)
                self.writer.add_image("Validation HQ Image", x2, self.log_step_img[phase])
                x3 = vutils.make_grid(hq_pred.detach().data, normalize=True, scale_each=True)
                self.writer.add_image("Validation Recovered Image", x3, self.log_step_img[phase])
                self.log_step_img[phase] += 1

        psnr_mean /= len(self.datasets[phase])
        lpips_mean /= len(self.datasets[phase])
        # tensorboard
        self.writer.add_scalar('Validation PSRN', psnr_mean, self.log_step[phase])
        self.writer.add_scalar('Validation LPIPS', lpips_mean, self.log_step[phase])
        self.log_step[phase] += 1
        # logging
        self.logger.info(f'PSNR={psnr_mean:5.2f}, LPIPS={lpips_mean:6.4f}')
        self.logger.info("="*60)

        self.model.train()

    def build_dataloader(self):
        """Build the training loader and, on rank 0, the optional validation loader."""
        super().build_dataloader()
        if self.rank == 0 and 'val' in self.configs.data:
            dataset_config = self.configs.data['val']
            self.datasets['val'] = create_dataset(dataset_config)
            self.dataloaders['val'] = udata.DataLoader(
                self.datasets['val'],
                batch_size=self.configs.train.batch[1],
                shuffle=False,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
class TrainerDiffusionFace(TrainerBase):
    """Trainer for a diffusion model that also maintains EMA copies of the weights.

    One EMA state dict is kept per rate in ``configs.train.ema_rates``, keyed
    by ``f"0{int(rate*1000)}"``.
    """

    def __init__(self, configs):
        # ema settings (needed by build_model, which the base constructor calls)
        self.ema_rates = OmegaConf.to_object(configs.train.ema_rates)
        super().__init__(configs)

    def init_logger(self):
        """Set up base logging and the folder for EMA checkpoints."""
        super().init_logger()

        save_dir = Path(self.configs.save_dir)
        ema_ckpt_dir = save_dir / 'ema_ckpts'
        if self.rank == 0:
            if not ema_ckpt_dir.exists():
                util_common.mkdir(ema_ckpt_dir, delete=False, parents=False)
            elif not self.configs.resume:
                util_common.mkdir(ema_ckpt_dir, delete=True, parents=False)

        self.ema_ckpt_dir = ema_ckpt_dir

    def resume_from_ckpt(self):
        """Resume the model (see the base class) and reload every EMA state."""
        super().resume_from_ckpt()

        def _load_ema_state(ema_state, ckpt):
            for key in ema_state.keys():
                ema_state[key] = deepcopy(ckpt[key].detach().data)

        if not self.configs.resume:
            return

        # same checkpoint selection as in the base class
        if isinstance(self.configs.resume, bool):
            ckpt_index = max([int(x.stem.split('_')[1]) for x in Path(self.ckpt_dir).glob('*.pth')])
            ckpt_path = str(Path(self.ckpt_dir) / f"model_{ckpt_index}.pth")
        else:
            ckpt_path = self.configs.resume
        assert os.path.isfile(ckpt_path)

        # EMA model
        for rate in self.ema_rates:
            ema_ckpt_path = self.ema_ckpt_dir / (f"ema0{int(rate*1000)}_"+Path(ckpt_path).name)
            ema_ckpt = torch.load(ema_ckpt_path, map_location=f"cuda:{self.rank}")
            _load_ema_state(self.ema_state[f"0{int(rate*1000)}"], ema_ckpt)

    def build_model(self):
        """Build the network, its EMA copy/states, the diffusion process and the timestep sampler."""
        # an empty dict (not the ``dict`` type) so that ``**params`` works
        # when the config has no "params" entry
        params = self.configs.model.get('params', {})
        model = util_common.get_obj_from_str(self.configs.model.target)(**params)
        self.ema_model = deepcopy(model.cuda())
        if self.num_gpus > 1:
            self.model = DDP(model.cuda(), device_ids=[self.rank,])  # wrap the network
        else:
            self.model = model.cuda()

        self.ema_state = {}
        for rate in self.ema_rates:
            self.ema_state[f"0{int(rate*1000)}"] = OrderedDict(
                {key: deepcopy(value.data) for key, value in self.model.state_dict().items()}
            )

        # model information
        self.print_model_info()

        params = self.configs.diffusion.get('params', {})
        self.base_diffusion = util_common.get_obj_from_str(self.configs.diffusion.target)(**params)
        self.sample_scheduler_diffusion = UniformSampler(self.base_diffusion.num_timesteps)

    def prepare_data(self, data, realesrgan=False):
        """Move every tensor of the batch to the GPU; ``realesrgan`` is unused."""
        data = {key: value.cuda() for key, value in data.items()}
        return data

    def training_step(self, data):
        """Run one optimizer step with gradient accumulation, then update the EMA states.

        Args:
            data: batch dict with ``image`` and optionally ``label``.
        """
        current_batchsize = data['image'].shape[0]
        micro_batchsize = self.configs.train.microbatch
        num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)
        use_fp16 = self.configs.train.use_fp16

        # The scaler has to persist across steps: a new one per step would
        # restart from the initial loss scale every time and never adapt.
        if use_fp16 and not hasattr(self, 'amp_scaler'):
            self.amp_scaler = amp.GradScaler()

        self.optimizer.zero_grad()
        for jj in range(0, current_batchsize, micro_batchsize):
            micro_data = {key: value[jj:jj+micro_batchsize,] for key, value in data.items()}
            last_batch = (jj+micro_batchsize >= current_batchsize)
            tt, weights = self.sample_scheduler_diffusion.sample(
                micro_data['image'].shape[0],
                device=f"cuda:{self.rank}",
                use_fp16=use_fp16
            )
            compute_losses = functools.partial(
                self.base_diffusion.training_losses,
                self.model,
                micro_data['image'],
                tt,
                model_kwargs={'y': micro_data['label']} if 'label' in micro_data else None,
            )
            # Gradients only need to be synchronised on the last micro-batch.
            # Both the forward and the backward pass must run inside
            # no_sync() for the synchronisation to be skipped.
            if last_batch or self.num_gpus <= 1:
                grad_sync = contextlib.nullcontext()
            else:
                grad_sync = self.model.no_sync()
            with grad_sync:
                if use_fp16:
                    with amp.autocast():
                        losses = compute_losses()
                        loss = (losses["loss"] * weights).mean() / num_grad_accumulate
                    self.amp_scaler.scale(loss).backward()
                else:
                    losses = compute_losses()
                    loss = (losses["loss"] * weights).mean() / num_grad_accumulate
                    loss.backward()

            # make logging
            self.log_step_train(losses, tt, micro_data, last_batch)

        if use_fp16:
            self.amp_scaler.step(self.optimizer)
            self.amp_scaler.update()
        else:
            self.optimizer.step()

        self.update_ema_model()

    def update_ema_model(self):
        """Blend the current weights into every EMA state (rank 0 only)."""
        if self.num_gpus > 1:
            dist.barrier()
        if self.rank == 0:
            source_state = self.model.state_dict()
            for rate in self.ema_rates:
                ema_state = self.ema_state[f"0{int(rate*1000)}"]
                for key, value in ema_state.items():
                    source = source_state[key].detach()
                    if value.is_floating_point():
                        value.mul_(rate).add_(source, alpha=1-rate)
                    else:
                        # integer entries cannot be scaled by a float in place
                        value.copy_(source)

    def adjust_lr(self, ii=None):
        """Linear warm-up until the first milestone, then halve the rate at each later milestone.

        Args:
            ii: iteration index. Defaults to ``self.current_iters - 1``, the
                loop index of the iteration that has just been run, so that
                the argument-less call in ``train`` and the replay in
                ``resume_from_ckpt`` use the same indexing.
        """
        if ii is None:
            ii = self.current_iters - 1
        base_lr = self.configs.train.lr
        linear_steps = self.configs.train.milestones[0]
        if ii <= linear_steps:
            for params_group in self.optimizer.param_groups:
                params_group['lr'] = (ii / linear_steps) * base_lr
        elif ii in self.configs.train.milestones:
            for params_group in self.optimizer.param_groups:
                params_group['lr'] *= 0.5

    def _reset_loss_stats(self, keys, num_timesteps):
        """(Re)create the per-timestep loss sums and sample counts."""
        self.loss_mean = {key: torch.zeros(size=(num_timesteps,), dtype=torch.float64)
                          for key in keys}
        self.loss_count = torch.zeros(size=(num_timesteps,), dtype=torch.float64)

    def log_step_train(self, loss, tt, batch, flag=False, phase='train'):
        '''
        Accumulate per-timestep losses and periodically log them (rank 0 only).

        param loss: a dict recording the loss informations
        param tt: 1-D tensor, time steps
        param batch: micro-batch dict containing 'image'
        param flag: True on the last micro-batch of an iteration
        '''
        if self.rank != 0:
            return

        num_timesteps = self.base_diffusion.num_timesteps
        log_freq = self.configs.train.log_freq
        save_freq = self.configs.train.save_freq

        # Created lazily and reset after each report, so the accumulators also
        # exist after resuming in the middle of a logging window.
        if not hasattr(self, 'loss_mean'):
            self._reset_loss_stats(loss.keys(), num_timesteps)

        # The accumulators live on the CPU, so index them with CPU indices;
        # index_add_ also counts timesteps that occur more than once in ``tt``.
        steps = tt.detach().cpu().long()
        for key, value in loss.items():
            src = value.detach().cpu().to(torch.float64).expand(steps.shape)
            self.loss_mean[key].index_add_(0, steps, src)
        self.loss_count.index_add_(0, steps, torch.ones(steps.shape[0], dtype=torch.float64))

        if self.current_iters % log_freq[0] == 0 and flag:
            if torch.any(self.loss_count == 0):
                self.loss_count += 1e-4  # avoid dividing by zero for unseen timesteps
            for key in loss.keys():
                self.loss_mean[key] /= self.loss_count
            log_str = 'Train: {:05d}/{:05d}, Loss: '.format(
                self.current_iters // 100,
                self.configs.train.iterations // 100)
            for kk in [1, num_timesteps // 2, num_timesteps]:
                if 'vb' in self.loss_mean:
                    log_str += 't({:d}):{:.2e}/{:.2e}/{:.2e}, '.format(
                        kk,
                        self.loss_mean['loss'][kk-1].item(),
                        self.loss_mean['mse'][kk-1].item(),
                        self.loss_mean['vb'][kk-1].item(),
                    )
                else:
                    log_str += 't({:d}):{:.2e}, '.format(kk, self.loss_mean['loss'][kk-1].item())
            log_str += 'lr:{:.2e}'.format(self.optimizer.param_groups[0]['lr'])
            self.logger.info(log_str)
            # tensorboard
            for kk in [1, num_timesteps // 2, num_timesteps]:
                self.writer.add_scalar(f'Loss-Step-{kk}',
                                       self.loss_mean['loss'][kk-1].item(),
                                       self.log_step[phase])
            self.log_step[phase] += 1
            self._reset_loss_stats(loss.keys(), num_timesteps)
        if self.current_iters % log_freq[1] == 0 and flag:
            x1 = vutils.make_grid(batch['image'], normalize=True, scale_each=True)
            self.writer.add_image("Training Image", x1, self.log_step_img[phase])
            self.log_step_img[phase] += 1

        if self.current_iters % save_freq == 1 and flag:
            self.tic = time.time()
        if self.current_iters % save_freq == 0 and flag:
            self.toc = time.time()
            # self.tic is missing when no "first iteration of a window" was seen
            elaplsed = (self.toc - getattr(self, 'tic', self.toc)) * num_timesteps / (num_timesteps - 1)
            self.logger.info(f"Elapsed time: {elaplsed:.2f}s")
            self.logger.info("="*130)

    def validation(self, phase='val'):
        """Sample with the first EMA model and log a grid of intermediate results."""
        self.reload_ema_model(self.ema_rates[0])
        self.ema_model.eval()

        # sampling steps whose result is added to the grid (besides the first one)
        indices = [int(self.base_diffusion.num_timesteps * x) for x in [0.25, 0.5, 0.75, 1]]
        chn = 3
        batch_size = self.configs.train.batch[1]
        shape = (batch_size, chn,) + (self.configs.data.train.params.out_size,) * 2
        num_iters = 0
        # noise = torch.randn(shape,
        #                     dtype=torch.float32,
        #                     generator=torch.Generator('cpu').manual_seed(10000)).cuda()
        for sample in self.base_diffusion.p_sample_loop_progressive(
                model=self.ema_model,
                shape=shape,
                noise=None,
                clip_denoised=True,
                model_kwargs=None,
                device=f"cuda:{self.rank}",
                progress=False
        ):
            num_iters += 1
            img = util_image.normalize_th(sample['sample'], reverse=True)
            if num_iters == 1:
                im_recover = img
            elif num_iters in indices:
                im_recover = torch.cat((im_recover, img), dim=1)
        im_recover = rearrange(im_recover, 'b (k c) h w -> (b k) c h w', c=chn)
        x1 = vutils.make_grid(im_recover, nrow=len(indices)+1, normalize=False)
        self.writer.add_image('Validation Sample', x1, self.log_step_img[phase])
        self.log_step_img[phase] += 1

    def save_ckpt(self):
        """Save the model checkpoint and one EMA checkpoint per rate (rank 0 only)."""
        if self.rank == 0:
            ckpt_path = self.ckpt_dir / 'model_{:d}.pth'.format(self.current_iters)
            torch.save({'iters_start': self.current_iters,
                        'log_step': {phase: self.log_step[phase] for phase in ['train', 'val']},
                        'log_step_img': {phase: self.log_step_img[phase] for phase in ['train', 'val']},
                        'state_dict': self.model.state_dict()}, ckpt_path)
            for rate in self.ema_rates:
                ema_ckpt_path = self.ema_ckpt_dir / (f"ema0{int(rate*1000)}_"+ckpt_path.name)
                torch.save(self.ema_state[f"0{int(rate*1000)}"], ema_ckpt_path)

    def calculate_lpips(self, inputs, targets):
        """Return the mean LPIPS between ``inputs`` and ``targets`` after mapping them to [-1, 1].

        NOTE(review): relies on ``self.lpips_loss``, which this class's
        ``build_model`` does not create — confirm where it is set.
        """
        inputs, targets = [(x-0.5)/0.5 for x in [inputs, targets]]  # [-1, 1]
        with torch.no_grad():
            mean_lpips = self.lpips_loss(inputs, targets)
        return mean_lpips.mean().item()

    def reload_ema_model(self, rate):
        """Load the EMA weights for ``rate`` into ``self.ema_model``.

        The EMA keys come from ``self.model.state_dict()``; a leading
        ``module.`` is removed when present (and only then), so the keys match
        the bare ``self.ema_model`` with and without the DDP wrapper.
        """
        prefix = 'module.'
        model_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in self.ema_state[f"0{int(rate*1000)}"].items()
        }
        self.ema_model.load_state_dict(model_state)
def my_worker_init_fn(worker_id):
    """Re-seed NumPy's global generator in a data-loader worker.

    The new seed is the first word of the current generator state plus
    ``worker_id``, wrapped into the range ``np.random.seed`` accepts.

    Args:
        worker_id: integer id of the worker.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]
    np.random.seed((base_seed + worker_id) % 2**32)
if __name__ == '__main__':
    # Visual demo: blend two test images with Gaussian noise at several
    # strengths and display the resulting grid.
    from utils import util_image
    from einops import rearrange

    first = util_image.imread('./testdata/inpainting/val/places/Places365_val_00012685_crop000.png',
                              chn = 'rgb', dtype='float32')
    second = util_image.imread('./testdata/inpainting/val/places/Places365_val_00014886_crop000.png',
                               chn = 'rgb', dtype='float32')
    batch = rearrange(np.stack([first, second], axis=3), 'h w c b -> b c h w')

    # Panels are stacked along the channel axis; each noisier version is
    # placed in front of the ones built so far, the clean copy stays last.
    panels = [batch.copy()]
    for weight in (0.8, 0.4, 0.1, 0):
        noisy = batch * weight + np.random.randn(*batch.shape) * (1 - weight)
        panels.insert(0, noisy)
    grid = np.concatenate(panels, axis=1)

    grid = np.clip(grid, 0.0, 1.0)
    grid = rearrange(grid, 'b (k c) h w -> (b k) c h w', k=5)
    grid_tensor = torch.from_numpy(grid)
    shown = vutils.make_grid(grid_tensor, nrow=5, normalize=True, scale_each=True).numpy()
    util_image.imshow(np.concatenate([first, second], axis=0))
    util_image.imshow(shown.transpose((1, 2, 0)))