Spaces:
Runtime error
Runtime error
import functools | |
import torch | |
import tops | |
from tops import logger | |
from dp2.utils import forward_D_fake | |
from .utils import nsgan_d_loss, nsgan_g_loss | |
from .r1_regularization import r1_regularization | |
from .pl_regularization import PLRegularization | |
class StyleGAN2Loss:
    """Discriminator and generator losses for StyleGAN2-style training.

    Discriminator objective: non-saturating GAN loss, plus an epsilon penalty
    on the squared real scores, plus an R1 gradient penalty evaluated every
    ``lazy_reg_interval`` discriminator steps.

    Generator objective: non-saturating GAN loss, optionally plus a
    ``PLRegularization`` term (presumably StyleGAN2 path-length
    regularization -- confirm against ``pl_regularization``).
    """

    def __init__(
        self,
        D,
        G,
        r1_opts: dict,
        EP_lambd: float,
        lazy_reg_interval: int,
        lazy_regularization: bool,
        pl_reg_opts: dict,
    ) -> None:
        """
        Args:
            D: Discriminator. Called as ``D(**batch)``; must return a mapping
                containing ``"score"``.
            G: Generator. Called as ``G(**batch)``; must return a mapping
                containing ``"img"``.
            r1_opts: Keyword arguments bound into ``r1_regularization``.
            EP_lambd: Weight of the epsilon penalty (squared real scores).
            lazy_reg_interval: Number of D steps between R1 evaluations.
            lazy_regularization: Enables the interval-based R1 schedule
                (see the note in ``D_loss``). Also forwarded to
                ``r1_regularization``.
            pl_reg_opts: Mapping of keyword arguments for
                ``PLRegularization``. Must contain ``"weight"``; the term is
                enabled when it is > 0. In that case it must also contain
                ``"start_nimg"``, which is compared against
                ``logger.global_step()``.
        """
        self.gradient_step_D = 0
        self._lazy_reg_interval = lazy_reg_interval
        self.D = D
        self.G = G
        self.EP_lambd = EP_lambd
        self.lazy_regularization = lazy_regularization
        # Bind the static R1 options once; D_loss only supplies per-step tensors.
        self.r1_reg = functools.partial(
            r1_regularization,
            **r1_opts,
            lazy_reg_interval=lazy_reg_interval,
            lazy_regularization=lazy_regularization,
        )

        # Always define the PL attributes so the object has a consistent shape.
        self.do_PL_Reg = False
        self.pl_reg = None
        self.pl_start_nimg = None
        # Use item access: it works for any mapping (plain dict, EasyDict,
        # OmegaConf DictConfig), while attribute access fails on a plain dict.
        if pl_reg_opts["weight"] > 0:
            self.pl_reg = PLRegularization(**pl_reg_opts)
            self.do_PL_Reg = True
            self.pl_start_nimg = pl_reg_opts["start_nimg"]

    def D_loss(self, batch: dict, grad_scaler):
        """Compute the discriminator loss for one step.

        Args:
            batch: Keyword arguments for ``G`` and ``D``. Must contain
                ``"img"`` (the real images); ``"mask"`` is read when R1 is
                evaluated. ``batch["img"]`` is reassigned to a detached tensor
                before returning.
            grad_scaler: AMP gradient scaler, forwarded to the R1 penalty.

        Returns:
            ``(loss, to_log)``: the batch-mean loss and a dict of detached,
            mean-reduced tensors for logging.
        """
        to_log = {}
        # NOTE(review): with lazy_regularization=False this condition is never
        # true, so R1 is skipped entirely instead of being applied every step.
        # r1_reg also receives lazy_regularization, which suggests a non-lazy
        # path was intended. Confirm r1_regularization supports
        # lazy_regularization=False before changing this.
        do_GP = (
            self.lazy_regularization
            and self.gradient_step_D % self._lazy_reg_interval == 0
        )
        if do_GP:
            # R1 differentiates D's real score w.r.t. the real images.
            batch["img"] = batch["img"].detach().requires_grad_(True)

        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            # G is only sampled here; no gradients are tracked through it.
            with torch.no_grad():
                G_fake = self.G(**batch, update_emas=True)
            D_out_real = self.D(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
            real_score = D_out_real["score"]
            fake_score = D_out_fake["score"]

            # Non-saturating loss, one value per sample.
            nsgan_loss = nsgan_d_loss(real_score, fake_score)
            tops.assert_shape(nsgan_loss, (batch["img"].shape[0],))
            to_log["d_loss"] = nsgan_loss.mean()
            total_loss = nsgan_loss

            # Epsilon penalty on the squared real scores.
            epsilon_penalty = real_score.pow(2).view(-1)
            to_log["epsilon_penalty"] = epsilon_penalty.mean()
            tops.assert_shape(epsilon_penalty, total_loss.shape)
            total_loss = total_loss + epsilon_penalty * self.EP_lambd

        # R1 gradient penalty with lazy regularization. It runs outside the
        # autocast region because the penalty applies its own autocast and
        # uses the grad scaler.
        if do_GP:
            gradient_pen, grad_unscaled = self.r1_reg(
                batch["img"], real_score, batch["mask"], scaler=grad_scaler)
            to_log["r1_gradient_penalty"] = grad_unscaled.mean()
            tops.assert_shape(gradient_pen, total_loss.shape)
            total_loss = total_loss + gradient_pen

        batch["img"] = batch["img"].detach().requires_grad_(False)

        # "score" was indexed unconditionally above, so no membership check
        # is needed here.
        to_log["real_scores"] = real_score
        to_log["real_logits_sign"] = real_score.sign()
        to_log["fake_logits_sign"] = fake_score.sign()
        to_log["fake_scores"] = fake_score
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        self.gradient_step_D += 1
        return total_loss.mean(), to_log

    def G_loss(self, batch: dict, grad_scaler):
        """Compute the generator loss for one step.

        Args:
            batch: Keyword arguments for ``G``. Must contain ``"img"``, whose
                leading dimension fixes the expected per-sample loss shape.
            grad_scaler: AMP gradient scaler, forwarded to ``self.pl_reg``.

        Returns:
            ``(loss, to_log)``: the batch-mean loss and a dict of detached,
            mean-reduced tensors for logging.
        """
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            to_log = {}
            G_fake = self.G(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
            # Non-saturating adversarial loss, one value per sample.
            total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
            to_log["g_loss"] = total_loss.mean()
            tops.assert_shape(total_loss, (batch["img"].shape[0],))

        # PL regularization is enabled only once the logger's global step
        # reaches the configured start point.
        if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
            pl_reg, pl_to_log = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
            total_loss = total_loss + pl_reg.mean()
            to_log.update(pl_to_log)
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        return total_loss.mean(), to_log