import functools

import torch
import tops
from tops import logger

from dp2.utils import forward_D_fake
from .utils import nsgan_d_loss, nsgan_g_loss
from .r1_regularization import r1_regularization
from .pl_regularization import PLRegularization


class StyleGAN2Loss:
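    """StyleGAN2 loss terms for the dp2 discriminator and generator.

    Combines a non-saturating GAN loss with an epsilon penalty on real scores,
    lazy R1 gradient-penalty regularization for D, and optional path-length
    regularization for G.
    """
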
    def __init__(
        self,
        D,
        G,
        r1_opts: dict,
        EP_lambd: float,
        lazy_reg_interval: int,
        lazy_regularization: bool,
        pl_reg_opts: dict,
    ) -> None:
        self.gradient_step_D = 0
        self._lazy_reg_interval = lazy_reg_interval
        self.D = D
        self.G = G
        self.EP_lambd = EP_lambd
        self.lazy_regularization = lazy_regularization
        self.r1_reg = functools.partial(
            r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval,
            lazy_regularization=lazy_regularization)
        self.do_PL_Reg = False
        # pl_reg_opts is expected to support attribute access (e.g. an EasyDict/DictConfig)
        # despite the dict annotation.
        if pl_reg_opts.weight > 0:
            self.pl_reg = PLRegularization(**pl_reg_opts)
            self.do_PL_Reg = True
            self.pl_start_nimg = pl_reg_opts.start_nimg

    def D_loss(self, batch: dict, grad_scaler):
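        """Compute the discriminator loss for one batch.

        Returns the per-sample non-saturating loss plus the epsilon penalty,
        with the R1 gradient penalty added on lazy-regularization steps, and a
        dict of detached scalars for logging.
        """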
        to_log = {}
        # Forward through G and D
        do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0
        if do_GP:
            batch["img"] = batch["img"].detach().requires_grad_(True)
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            with torch.no_grad():
                G_fake = self.G(**batch, update_emas=True)
            D_out_real = self.D(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)

            # Non-saturating loss
            nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"])
            tops.assert_shape(nsgan_loss, (batch["img"].shape[0],))
            to_log["d_loss"] = nsgan_loss.mean()
            total_loss = nsgan_loss

            # Epsilon penalty: penalizes squared real scores to keep them close to zero.
            epsilon_penalty = D_out_real["score"].pow(2).view(-1)
            to_log["epsilon_penalty"] = epsilon_penalty.mean()
            tops.assert_shape(epsilon_penalty, total_loss.shape)
            total_loss = total_loss + epsilon_penalty * self.EP_lambd

        # Improved (R1) gradient penalty with lazy regularization.
        # The gradient penalty applies its own specialized autocast.
        if do_GP:
            gradient_pen, grad_unscaled = self.r1_reg(
                batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler)
            to_log["r1_gradient_penalty"] = grad_unscaled.mean()
            tops.assert_shape(gradient_pen, total_loss.shape)
            total_loss = total_loss + gradient_pen
            batch["img"] = batch["img"].detach().requires_grad_(False)

        if "score" in D_out_real:
            to_log["real_scores"] = D_out_real["score"]
            to_log["real_logits_sign"] = D_out_real["score"].sign()
            to_log["fake_logits_sign"] = D_out_fake["score"].sign()
            to_log["fake_scores"] = D_out_fake["score"]
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        self.gradient_step_D += 1
        return total_loss.mean(), to_log

    def G_loss(self, batch: dict, grad_scaler):
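        """Compute the generator loss for one batch.

        Returns the non-saturating generator loss, with path-length
        regularization added once `logger.global_step()` reaches
        `self.pl_start_nimg`, and a dict of detached scalars for logging.
        """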
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            to_log = {}
            # Forward through G and D
            G_fake = self.G(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
            # Adversarial loss
            total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
            to_log["g_loss"] = total_loss.mean()
            tops.assert_shape(total_loss, (batch["img"].shape[0],))

        # Path-length regularization, enabled once global_step reaches pl_start_nimg.
        if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
            pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
            total_loss = total_loss + pl_reg.mean()
            to_log.update(to_log_)
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        return total_loss.mean(), to_log
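
# Usage sketch: how a training loop might drive this loss class. The
# discriminator, generator, optimizers, the batch dict, and the option values
# below are hypothetical placeholders; r1_opts is forwarded as keyword
# arguments to r1_regularization, and pl_reg_opts must be attribute-accessible
# with at least `weight` and `start_nimg`.
#
#     scaler = torch.cuda.amp.GradScaler()
#     loss_fnc = StyleGAN2Loss(
#         D=discriminator,
#         G=generator,
#         r1_opts=r1_config,            # kwargs forwarded to r1_regularization
#         EP_lambd=0.001,
#         lazy_reg_interval=16,
#         lazy_regularization=True,
#         pl_reg_opts=pl_config,        # e.g. EasyDict(weight=2.0, start_nimg=int(1e6), ...)
#     )
#
#     # Discriminator step
#     d_loss, d_log = loss_fnc.D_loss(batch, grad_scaler=scaler)
#     scaler.scale(d_loss).backward()
#     scaler.step(D_optimizer)
#
#     # Generator step
#     g_loss, g_log = loss_fnc.G_loss(batch, grad_scaler=scaler)
#     scaler.scale(g_loss).backward()
#     scaler.step(G_optimizer)
#     scaler.update()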