import dataclasses
import math

import torch
from PIL import Image
from torch import optim
from torch.nn import functional as FF
from torchvision import transforms
from tqdm import tqdm

from .lpips import util


def noise_regularize(noises):
    # Penalize spatial autocorrelation of the noise maps across multiple scales.
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            if size <= 8:
                break

            # Downsample by 2x (average pooling) and repeat at the coarser scale.
            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2

    return loss


def noise_normalize_(noises):
    # Re-normalize each noise map to zero mean and unit standard deviation in place.
    for noise in noises:
        mean = noise.mean()
        std = noise.std()

        noise.data.add_(-mean).div_(std)


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # Cosine ramp-down with a linear warmup at the start of optimization.
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp


def latent_noise(latent, strength):
    noise = torch.randn_like(latent) * strength

    return latent + noise


def make_image(tensor):
    # Convert a generator output in [-1, 1] to a uint8 HWC numpy array.
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


@dataclasses.dataclass
class InverseConfig:
    lr_warmup: float = 0.05
    lr_decay: float = 0.25
    lr: float = 0.1
    noise: float = 0.05
    noise_decay: float = 0.75
    step: int = 1000
    noise_regularize: float = 1e5
    mse: float = 0.0
    w_plus: bool = False


def inverse_image(
    g_ema,
    image,
    image_size=256,
    config=InverseConfig(),
):
    device = "cuda"
    args = config

    n_mean_latent = 10000

    resize = min(image_size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []
    img = transform(image)
    imgs.append(img)
    imgs = torch.stack(imgs, 0).to(device)

    with torch.no_grad():
        # Estimate the mean and standard deviation of W from random z samples.
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    # Start optimization from the mean latent (optionally expanded to W+).
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampdown=args.lr_decay, rampup=args.lr_warmup)
        optimizer.param_groups[0]["lr"] = lr

        # Anneal the latent perturbation to zero over the first part of optimization.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
        img_gen, F = g_ema.generate(latent, noise)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            # Average-pool the generated image down to 256px for the perceptual loss.
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = FF.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Regenerate the image from the last recorded latent and collect the outputs.
    latent, noise = g_ema.prepare(
        [latent_path[-1]], input_is_latent=True, noise=noises
    )
    img_gen, F = g_ema.generate(latent, noise)

    img_ar = make_image(img_gen)

    i = 0

    noise_single = []
    for noise in noises:
        noise_single.append(noise[i : i + 1])

    result = {
        "latent": latent,
        "noise": noise_single,
        "F": F,
        "sample": img_gen,
    }

    pil_img = Image.fromarray(img_ar[i])
    pil_img.save("project.png")

    return result