import torch
import tops
import numpy as np
from sg3_torch_utils.ops import conv2d_gradfix

pl_mean_total = torch.zeros([])


class PLRegularization:
    """Path length regularization (StyleGAN2), adapted for mixed precision and masked images."""

    def __init__(self, weight: float, batch_shrink: int, pl_decay: float, scale_by_mask: bool, **kwargs):
        self.pl_mean = torch.zeros([], device=tops.get_device())
        self.pl_weight = weight
        self.batch_shrink = batch_shrink
        self.pl_decay = pl_decay
        self.scale_by_mask = scale_by_mask

    def __call__(self, G, batch, grad_scaler):
        # Run the regularizer on a smaller batch to reduce cost.
        batch_size = batch["img"].shape[0] // self.batch_shrink
        # Shrink every per-sample tensor, but carry the shared embed_map over unsliced.
        sliced = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
        if "embed_map" in batch:
            sliced["embed_map"] = batch["embed_map"]
        batch = sliced
        z = G.get_z(batch["img"])
        with torch.cuda.amp.autocast(tops.AMP()):
            gen_ws = G.style_net(z)
            gen_img = G(**batch, w=gen_ws)["img"].float()
            # Random direction in image space, normalized by image resolution.
            pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
            with conv2d_gradfix.no_weight_gradients():
                # Sums over HWC
                pl_grads = torch.autograd.grad(
                    outputs=[grad_scaler.scale(gen_img * pl_noise)],
                    inputs=[gen_ws],
                    create_graph=True,
                    grad_outputs=torch.ones_like(gen_img),
                    only_inputs=True)[0]

        # Undo the loss scaling applied for mixed-precision gradients.
        pl_grads = pl_grads.float() / grad_scaler.get_scale()
        if self.scale_by_mask:
            # Percentage of pixels known
            scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
            pl_grads = pl_grads / scaling
        pl_lengths = pl_grads.square().sum(1).sqrt()
        # Exponential moving average of the path length; skip the update if it turned NaN.
        pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        if not torch.isnan(pl_mean).any():
            self.pl_mean.copy_(pl_mean.detach())
        pl_penalty = (pl_lengths - pl_mean).square()
        to_log = dict(pl_penalty=pl_penalty.mean().detach())
        return pl_penalty.view(-1) * self.pl_weight, to_log
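

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training code). The toy generator
# below only mimics the interface PLRegularization relies on: ``get_z``,
# ``style_net``, and a forward pass that accepts ``w`` and returns a dict
# with an "img" entry. The real generator, batch, and GradScaler come from
# the training loop; this sketch assumes the module's own dependencies
# (tops, sg3_torch_utils) are importable and configured.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyGenerator(torch.nn.Module):
        z_dim = 64

        def __init__(self):
            super().__init__()
            self.style_net = torch.nn.Linear(self.z_dim, self.z_dim)
            self.to_rgb = torch.nn.Linear(self.z_dim, 3)

        def get_z(self, img):
            # One latent vector per image in the batch.
            return torch.randn(img.shape[0], self.z_dim, device=img.device)

        def forward(self, img, mask, w):
            # Broadcast a per-sample color over the image so the output depends on w.
            color = self.to_rgb(w).view(-1, 3, 1, 1)
            return {"img": color * torch.ones_like(img)}

    device = tops.get_device()
    G = _ToyGenerator().to(device)
    batch = {
        "img": torch.randn(4, 3, 32, 32, device=device),
        "mask": torch.ones(4, 1, 32, 32, device=device),
    }
    grad_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    pl_reg = PLRegularization(weight=2.0, batch_shrink=2, pl_decay=0.01, scale_by_mask=True)
    penalty, to_log = pl_reg(G, batch, grad_scaler)
    # In a training loop the penalty would be added to the generator loss before backward.
    grad_scaler.scale(penalty.mean()).backward()
    print(penalty.shape, to_log["pl_penalty"].item())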