import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel.distributed import DistributedDataParallel as DDP

from .nn import mean_flat
# from . import dist_util
class AdversarialLoss(nn.Module):
    """Adversarial loss with an internal PatchGAN discriminator.

    Supports vanilla GAN ('GAN') and WGAN with gradient penalty ('WGAN_GP').
    """

    def __init__(self, gan_type='WGAN_GP', gan_k=1, lr_dis=1e-5):
        super(AdversarialLoss, self).__init__()
        self.gan_type = gan_type
        self.gan_k = gan_k
        model = NLayerDiscriminator().cuda()
        # DDP expects device indices rather than a bare torch.device('cuda');
        # torch.distributed must already be initialized before this runs.
        device = torch.cuda.current_device()
        self.discriminator = DDP(
            model,
            device_ids=[device],
            output_device=device,
            broadcast_buffers=False,
            bucket_cap_mb=128,
            find_unused_parameters=False,
        )
        if gan_type in ['WGAN_GP', 'GAN']:
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                lr=lr_dis
            )
    def forward(self, fake, real):
        # Detach the generator output so the discriminator updates below do
        # not backpropagate into the generator.
        fake_detach = fake.detach()
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type.find('WGAN') >= 0:
                # Wasserstein critic loss: push real scores up, fake scores down.
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # Gradient penalty on random interpolates between real and fake.
                    epsilon = torch.rand(real.size(0), 1, 1, 1, device=real.device)
                    epsilon = epsilon.expand(real.size())
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            elif self.gan_type == 'GAN':
                # Vanilla GAN: binary cross-entropy on the raw logits.
                loss_d = F.binary_cross_entropy_with_logits(
                    d_real, torch.ones_like(d_real)
                ) + F.binary_cross_entropy_with_logits(
                    d_fake, torch.zeros_like(d_fake)
                )
            else:
                raise NotImplementedError('GAN type [%s] is not supported' % self.gan_type)

            # Discriminator update
            loss_d.backward()
            self.optimizer.step()

        # Generator loss: score the (non-detached) fake so gradients flow
        # back into the generator.
        d_fake_for_g = self.discriminator(fake)
        if self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g
        else:
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, torch.ones_like(d_fake_for_g), reduction='none'
            )
        return mean_flat(loss_g)
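
# Usage sketch (illustrative, not part of the original module): in a training
# step, `fake` is the generator output and `real` the data batch. The call
# updates the discriminator `gan_k` times, then returns a per-sample generator
# loss that can be weighted into the main objective; the 0.01 weight below is
# an assumed value, not one taken from this code.
#
#   adv_loss = AdversarialLoss(gan_type='WGAN_GP', gan_k=1, lr_dis=1e-5)
#   loss_g = adv_loss(fake, real)                   # shape: [batch_size]
#   total_loss = recon_loss + 0.01 * loss_g.mean()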
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=True)


def conv7x7(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=7,
                     stride=stride, padding=3, bias=True)
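
# Note (illustrative): with kernel 3 / padding 1 (or kernel 7 / padding 3),
# spatial size is preserved at stride 1 and halved at stride 2:
#
#   conv3x3(32, 64)(torch.randn(1, 32, 64, 64)).shape     # [1, 64, 64, 64]
#   conv3x3(32, 32, 2)(torch.randn(1, 32, 64, 64)).shape  # [1, 32, 32, 32]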
class Discriminator(nn.Module):
    """Fully convolutional discriminator with instance normalization.

    A 7x7 stem followed by alternating channel-expanding (stride-1) and
    downsampling (stride-2) 3x3 convolutions, ending in a 1-channel score
    map. Note that AdversarialLoss above instantiates NLayerDiscriminator,
    not this class.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = conv7x7(3, 32)
        self.norm1 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU1 = nn.LeakyReLU(0.2)
        self.conv2 = conv3x3(32, 32, 2)
        self.norm2 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU2 = nn.LeakyReLU(0.2)
        self.conv3 = conv3x3(32, 64)
        self.norm3 = nn.InstanceNorm2d(64, affine=True)
        self.LReLU3 = nn.LeakyReLU(0.2)
        self.conv4 = conv3x3(64, 64, 2)
        self.norm4 = nn.InstanceNorm2d(64, affine=True)
        self.LReLU4 = nn.LeakyReLU(0.2)
        self.conv5 = conv3x3(64, 128)
        self.norm5 = nn.InstanceNorm2d(128, affine=True)
        self.LReLU5 = nn.LeakyReLU(0.2)
        self.conv6 = conv3x3(128, 128, 2)
        self.norm6 = nn.InstanceNorm2d(128, affine=True)
        self.LReLU6 = nn.LeakyReLU(0.2)
        self.conv7 = conv3x3(128, 256)
        self.norm7 = nn.InstanceNorm2d(256, affine=True)
        self.LReLU7 = nn.LeakyReLU(0.2)
        self.conv8 = conv3x3(256, 256, 2)
        self.norm8 = nn.InstanceNorm2d(256, affine=True)
        self.LReLU8 = nn.LeakyReLU(0.2)
        self.conv9 = conv3x3(256, 512)
        self.norm9 = nn.InstanceNorm2d(512, affine=True)
        self.LReLU9 = nn.LeakyReLU(0.2)
        self.conv10 = conv3x3(512, 512, 2)
        self.norm10 = nn.InstanceNorm2d(512, affine=True)
        self.LReLU10 = nn.LeakyReLU(0.2)
        self.conv11 = conv3x3(512, 32)
        self.norm11 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU11 = nn.LeakyReLU(0.2)
        self.conv12 = conv3x3(32, 1)

    def forward(self, x):
        x = self.LReLU1(self.norm1(self.conv1(x)))
        x = self.LReLU2(self.norm2(self.conv2(x)))
        x = self.LReLU3(self.norm3(self.conv3(x)))
        x = self.LReLU4(self.norm4(self.conv4(x)))
        x = self.LReLU5(self.norm5(self.conv5(x)))
        x = self.LReLU6(self.norm6(self.conv6(x)))
        x = self.LReLU7(self.norm7(self.conv7(x)))
        x = self.LReLU8(self.norm8(self.conv8(x)))
        x = self.LReLU9(self.norm9(self.conv9(x)))
        x = self.LReLU10(self.norm10(self.conv10(x)))
        x = self.LReLU11(self.norm11(self.conv11(x)))
        x = self.conv12(x)
        return x
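
# Shape sketch (illustrative): the five stride-2 convolutions downsample by
# 32x, so a 3x256x256 input yields a 1x8x8 score map:
#
#   d = Discriminator()
#   assert d(torch.randn(1, 3, 256, 256)).shape == (1, 1, 8, 8)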
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer factory.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running
    statistics (mean/stddev). For InstanceNorm, we do not use learnable
    affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x): return nn.Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
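
# Example (illustrative): the returned factory is called with a channel count
# to build the actual layer:
#
#   norm_layer = get_norm_layer('batch')
#   bn64 = norm_layer(64)  # nn.BatchNorm2d(64, affine=True, track_running_stats=True)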
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator."""

    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the first conv layer
            n_layers (int) -- the number of stride-2 conv layers in the discriminator
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = get_norm_layer(norm_type='instance')
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)
    def forward(self, input):
        """Standard forward."""
        return self.model(input)
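
# Shape sketch (assumed defaults: input_nc=3, ndf=64, n_layers=3): three
# stride-2 convs followed by two stride-1 kernel-4 convs map a 3x256x256
# image to a 1x30x30 patch score map, one score per ~70x70 receptive field:
#
#   net = NLayerDiscriminator()
#   out = net(torch.randn(1, 3, 256, 256))
#   assert out.shape == (1, 1, 30, 30)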