import copy
import functools
import glob
import os

import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from .glide_util import sample
from . import dist_util, logger
from .fp16_util import (
    make_master_params,
    master_params_to_model_params,
    model_grads_to_master_grads,
    unflatten_master_params,
    zero_grad,
)
from .nn import update_ema
from .vgg import VGG
from .adv import AdversarialLoss
from .resample import LossAwareSampler, UniformSampler
import torchvision.utils as tvu
import PIL.Image as Image

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0


class TrainLoop:
    """Training loop with optional fp16, EMA, perceptual (VGG) and adversarial losses.

    Only ``self.optimize_model`` is optimized. That is the full model when
    ``finetune_decoder`` is true, and ``model.encoder`` otherwise. Checkpoints,
    EMA weights and the optimizer state are saved every ``save_interval`` steps.
    Validation samples are written to ``<logger dir>/results`` at the same cadence.
    """

    def __init__(
        self,
        model,
        glide_options,
        diffusion,
        data,
        val_data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        finetune_decoder=False,
        mode='',
        use_vgg=False,
        use_gan=False,
        uncond_p=0,
        super_res=0,
    ):
        """Store configuration, optionally resume, and wrap the model for training.

        Args:
            model: network to train. It must expose ``.encoder`` when
                ``finetune_decoder`` is false.
            glide_options: dict read by ``val`` (keys such as ``sample_c``,
                ``num_samples``, ``image_size``, ``batch_size``,
                ``sample_respacing``, ``super_res``).
            diffusion: object providing ``training_losses`` and ``num_timesteps``.
            data / val_data: iterators yielding ``(batch, model_kwargs)``.
            batch_size: per-process batch size.
            microbatch: microbatch size; a value <= 0 means "use batch_size".
            lr: AdamW learning rate (constant, since ``_anneal_lr`` is a no-op).
            ema_rate: a float, or a comma-separated string of floats.
            log_interval / save_interval: step cadence for logging and for
                saving plus validation.
            resume_checkpoint: path, possibly containing ``ROOT``/``LATEST``
                placeholders (see ``find_resume_checkpoint``).
            lr_anneal_steps: when non-zero, ``run_loop`` stops after step
                ``lr_anneal_steps``.
            use_vgg / use_gan: enable the perceptual / adversarial loss modules.
        """
        self.model = model
        self.glide_options = glide_options
        self.diffusion = diffusion
        self.data = data
        self.val_data = val_data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = find_resume_checkpoint(resume_checkpoint)
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        if use_vgg:
            self.vgg = VGG(conv_index='22').to(dist_util.dev())
            print('use perc')
        else:
            self.vgg = None
        if use_gan:
            self.adv = AdversarialLoss()
            print('use adv')
        else:
            self.adv = None

        self.super_res = super_res
        # NOTE(review): uncond_p is stored but not read anywhere in this module.
        self.uncond_p = uncond_p
        self.mode = mode
        self.finetune_decoder = finetune_decoder

        # Only the optimized sub-network's parameters go to the optimizer/EMA/checkpoints.
        if finetune_decoder:
            self.optimize_model = self.model
        else:
            self.optimize_model = self.model.encoder

        self.model_params = list(self.optimize_model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint on rank 0, then broadcast them.

        The load is non-strict, so keys that do not match are silently ignored.
        """
        resume_checkpoint = self.resume_checkpoint
        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    th.load(resume_checkpoint, map_location="cpu"), strict=False
                )
        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA params for ``rate``, loaded from disk on rank 0 if available.

        Other ranks (and rank 0 without an EMA file) fall back to a copy of
        the current master params. EMA params are not broadcast; only rank 0
        ever writes them to disk in ``save``.
        """
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = th.load(ema_checkpoint, map_location=dist_util.dev())
                ema_params = self._state_dict_to_master_params(state_dict)

        return ema_params

    def _load_optimizer_state(self):
        """Best-effort restore of the AdamW state saved next to the resume checkpoint."""
        main_checkpoint = self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = th.load(opt_checkpoint, map_location="cpu")
            try:
                self.opt.load_state_dict(state_dict)
            except Exception as err:
                # Mismatched param groups (e.g. a different finetune_decoder
                # setting) must not abort training; start with a fresh
                # optimizer, but say so instead of failing silently.
                logger.warn(
                    f"could not load optimizer state from {opt_checkpoint} "
                    f"({err!r}); continuing with a freshly initialized optimizer"
                )

    def _setup_fp16(self):
        """Create fp32 master params and convert the model to fp16."""
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        """Train until step ``lr_anneal_steps`` (inclusive), or forever if it is 0.

        Logs every ``log_interval`` steps. Saves and runs ``val`` every
        ``save_interval`` steps. If the last step was not a save step, a final
        checkpoint is written. The bound is on ``self.step`` only; it is not
        offset by ``resume_step``.
        """
        while not self.lr_anneal_steps or self.step <= self.lr_anneal_steps:
            batch, model_kwargs = next(self.data)
            self.run_step(batch, model_kwargs)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                self.val(self.step)
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, model_kwargs):
        """Run one forward/backward pass, one optimizer update, and log the step."""
        self.forward_backward(batch, model_kwargs)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_backward(self, batch, model_kwargs):
        """Accumulate gradients over microbatches of ``batch``.

        Only the ``ref`` and ``low_res`` entries of ``model_kwargs`` are
        forwarded to the model. The VGG/adversarial losses are disabled for
        the first 100 steps. DDP gradient sync is skipped on all but the last
        microbatch.
        """
        zero_grad(self.model_params)

        # Loss modules depend only on the current step, not on the microbatch.
        if self.step < 100:
            vgg_loss = None
            adv_loss = None
        else:
            vgg_loss = self.vgg
            adv_loss = self.adv

        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                n: model_kwargs[n][i : i + self.microbatch].to(dist_util.dev())
                for n in model_kwargs
                if n in ('ref', 'low_res')
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                vgg_loss,
                adv_loss,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def val(self, step):
        """Sample ``glide_options['num_samples']`` images and save them as PNGs.

        For each sample, the output, the ground truth (``batch``) and the
        reference (``model_kwargs['ref_ori']``) are written to
        ``<logger dir>/results`` as
        ``{rank}_{img_id}_step{step}_{guidance_scale}_{output|gt|ref}.png``.
        Tensors are rescaled with ``(x + 1) * 0.5``, which presumably assumes
        they lie in [-1, 1].

        The unwrapped model is used directly. Under DDP, ``ddp_model.module``
        is ``self.model``; without CUDA, ``ddp_model`` is the bare model and
        has no ``.module``.
        """
        inner_model = self.model
        inner_model.eval()
        try:
            if dist.get_rank() == 0:
                print("sampling...")
            s_path = os.path.join(logger.get_dir(), 'results')
            os.makedirs(s_path, exist_ok=True)
            img_id = 0
            guidance_scale = self.glide_options['sample_c']
            while img_id < self.glide_options['num_samples']:
                batch, model_kwargs = next(self.val_data)
                with th.no_grad():
                    samples = sample(
                        glide_model=inner_model,
                        glide_options=self.glide_options,
                        side_x=self.glide_options['image_size'],
                        side_y=self.glide_options['image_size'],
                        prompt=model_kwargs,
                        batch_size=self.glide_options['batch_size'] // 2,
                        guidance_scale=guidance_scale,
                        device=dist_util.dev(),
                        prediction_respacing=self.glide_options['sample_respacing'],
                        upsample_enabled=self.glide_options['super_res'],
                        upsample_temp=0.997,
                        mode=self.mode,
                    )
                samples = samples.cpu()
                ref = model_kwargs['ref_ori']
                for i in range(samples.size(0)):
                    out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_output.png")
                    tvu.save_image((samples[i] + 1) * 0.5, out_path)
                    out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_gt.png")
                    tvu.save_image((batch[i] + 1) * 0.5, out_path)
                    out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_ref.png")
                    tvu.save_image((ref[i] + 1) * 0.5, out_path)
                    img_id += 1
        finally:
            # Always return to training mode, even if sampling fails.
            inner_model.train()

    def optimize_fp16(self):
        """fp16 update with dynamic loss scaling.

        If any gradient is non-finite, the update is skipped and the log loss
        scale is decreased by 1. Otherwise the master grads are unscaled, the
        optimizer steps, the EMAs are updated, and the scale grows by
        ``fp16_scale_growth``.
        """
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        # Undo the loss scaling applied in forward_backward.
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        """fp32 update: log the grad norm, step the optimizer, update the EMAs."""
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        """Log the global L2 norm of the master gradients.

        Parameters that received no gradient (``grad is None``) are skipped.
        """
        sqsum = 0.0
        for p in self.master_params:
            if p.grad is None:
                continue
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        """No-op: the learning rate stays at ``self.lr``.

        Here ``lr_anneal_steps`` only bounds ``run_loop``.
        """
        return

    def log_step(self):
        """Log the global step, the number of samples seen, and the loss scale (fp16)."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        """Write the model, EMA and optimizer checkpoints (rank 0 only), then barrier."""

        def save_checkpoint(rate, params):
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        """Build a state dict of ``optimize_model`` with parameters replaced by ``master_params``."""
        if self.use_fp16:
            master_params = unflatten_master_params(
                list(self.optimize_model.parameters()), master_params
            )
        state_dict = self.optimize_model.state_dict()
        for i, (name, _value) in enumerate(self.optimize_model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Inverse of ``_master_params_to_state_dict`` (flattened again under fp16)."""
        params = [state_dict[name] for name, _ in self.optimize_model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params


def parse_resume_step_from_filename(filename):
    """
    Parse checkpoint filenames and return their step count.

    Accepted forms are path/to/modelNNNNNN.pt and path/to/ema_RATE_NNNNNN.pt.
    Any other ``.pt`` name, or an unparsable number, yields 0.
    """
    filename = filename.split('/')[-1]
    assert filename.endswith(".pt")
    filename = filename[:-3]
    if filename.startswith("model"):
        split = filename[5:]
    elif filename.startswith("ema"):
        split = filename.split("_")[-1]
    else:
        return 0
    try:
        return int(split)
    except ValueError:
        return 0


def get_blob_logdir():
    """Return (creating it if needed) the ``checkpoints`` directory under the logger dir."""
    p = os.path.join(logger.get_dir(), "checkpoints")
    os.makedirs(p, exist_ok=True)
    return p


def find_resume_checkpoint(resume_checkpoint):
    """Resolve ``ROOT``/``LATEST`` placeholders in a resume path.

    ``ROOT`` becomes ``$AMLT_MAP_INPUT_DIR/checkpoints``, or
    ``OUTPUT/log/checkpoints`` when that variable is unset. ``LATEST`` is
    globbed as ``*.pt``. Among the matches, ``model*`` files are preferred,
    because ``ema_*`` files parse to the same step and must not be picked as
    the main model. The match with the highest step is returned. Returns
    None when nothing matches or no path is given.
    """
    # On your infrastructure, you may want to override this to automatically
    # discover the latest checkpoint on your blob storage, etc.
    if not resume_checkpoint:
        return None
    if "ROOT" in resume_checkpoint:
        maybe_root = os.environ.get("AMLT_MAP_INPUT_DIR")
        maybe_root = "OUTPUT/log" if not maybe_root else maybe_root
        root = os.path.join(maybe_root, "checkpoints")
        resume_checkpoint = resume_checkpoint.replace("ROOT", root)
    if "LATEST" in resume_checkpoint:
        files = glob.glob(resume_checkpoint.replace("LATEST", "*.pt"))
        if not files:
            return None
        model_files = [f for f in files if os.path.basename(f).startswith("model")]
        return max(model_files or files, key=parse_resume_step_from_filename)
    return resume_checkpoint


def find_ema_checkpoint(main_checkpoint, step, rate):
    """Return the path of ``ema_{rate}_{step:06d}.pt`` next to ``main_checkpoint``, if it exists."""
    if main_checkpoint is None:
        return None
    filename = f"ema_{rate}_{(step):06d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)
    if bf.exists(path):
        return path
    return None


def log_loss_dict(diffusion, ts, losses):
    """Log the mean of each loss, plus per-quartile means bucketed by timestep."""
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)