EnlightenGAN / models /networks.py
HenryGong's picture
Upload 84 files
aba0e05 verified
import torch
import os
import math
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
# from torch.utils.serialization import load_lua
from lib.nn import SynchronizedBatchNorm2d as SynBN2d
###############################################################################
# Functions
###############################################################################
def pad_tensor(input):
    """Reflect-pad dims 2 and 3 of ``input`` up to the next multiple of 16.

    Dim 2 is treated as height and dim 3 as width. When a dimension is
    already a multiple of 16 it receives no padding; otherwise the missing
    amount is split so that the left/top side gets the smaller half.

    Returns:
        A tuple ``(padded, pad_left, pad_right, pad_top, pad_bottom)``. When
        no padding is needed ``padded`` is the input object itself.
    """
    divide = 16
    rows, cols = input.shape[2], input.shape[3]

    def _split(extent):
        # Padding needed to reach the next multiple, as (leading, trailing).
        leftover = extent % divide
        if leftover == 0:
            return 0, 0
        missing = divide - leftover
        leading = int(missing / 2)
        return leading, int(missing - leading)

    pad_left, pad_right = _split(cols)
    pad_top, pad_bottom = _split(rows)

    if cols % divide != 0 or rows % divide != 0:
        reflect = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))
        input = reflect(input)

    # Sanity check on the (possibly padded) result.
    height, width = input.data.shape[2], input.data.shape[3]
    assert width % divide == 0, 'width cant divided by stride'
    assert height % divide == 0, 'height cant divided by stride'

    return input, pad_left, pad_right, pad_top, pad_bottom
def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
    """Crop the given border widths off dims 2 (height) and 3 (width).

    This undoes the padding reported by ``pad_tensor`` when called with the
    same four amounts.
    """
    rows, cols = input.shape[2], input.shape[3]
    row_span = slice(pad_top, rows - pad_bottom)
    col_span = slice(pad_left, cols - pad_right)
    return input[:, :, row_span, col_span]
def weights_init(m):
    """Initialise one module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains ``'Conv'`` get weights drawn from
    N(0, 0.02). Otherwise, modules whose class name contains
    ``'BatchNorm2d'`` get weights from N(1, 0.02) and a zero bias.
    Everything else is left untouched.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested 2-D normalisation layer.

    Args:
        norm_type: ``'batch'`` (``nn.BatchNorm2d``, affine), ``'instance'``
            (``nn.InstanceNorm2d``, non-affine) or ``'synBN'``
            (``SynBN2d``, affine).

    Returns:
        A ``functools.partial`` that takes the channel count and builds the
        layer.

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the names above.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'synBN':
        norm_layer = functools.partial(SynBN2d, affine=True)
    else:
        # The message previously formatted an undefined name (`norm`), which
        # turned this into a NameError instead of the intended exception.
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], skip=False, opt=None):
    """Build, place and initialise the generator named by ``which_model_netG``.

    Args:
        input_nc: number of input channels (ResNet/U-Net generators).
        output_nc: number of output channels (ResNet/U-Net generators).
        ngf: base number of generator filters.
        which_model_netG: one of ``'resnet_9blocks'``, ``'resnet_6blocks'``,
            ``'unet_128'``, ``'unet_256'``, ``'unet_512'``, ``'sid_unet'``,
            ``'sid_unet_shuffle'``, ``'sid_unet_resize'`` or ``'DnCNN'``.
        norm: normalisation name forwarded to ``get_norm_layer``.
        use_dropout: forwarded to the ResNet/U-Net generators.
        gpu_ids: CUDA device ids; an empty list keeps the network on CPU.
        skip: forwarded to the U-Net style generators.
        opt: options object forwarded to the generators that take one.

    Returns:
        The generator after ``weights_init`` has been applied. When
        ``gpu_ids`` is non-empty it is moved to ``gpu_ids[0]`` and wrapped in
        ``torch.nn.DataParallel``.

    Raises:
        NotImplementedError: if ``which_model_netG`` is not recognised.
    """
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        # opt must be forwarded: UnetSkipConnectionBlock reads opt.use_norm,
        # so building this variant without it always failed.
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    # Was `len(gpu_ids) >= 0`, which is always true and indexed gpu_ids[0]
    # even for an empty list (IndexError on CPU-only runs).
    if use_gpu:
        netG.cuda(device=gpu_ids[0])
        netG = torch.nn.DataParallel(netG, gpu_ids)
    netG.apply(weights_init)
    return netG
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False):
    """Build, place and initialise the discriminator named by ``which_model_netD``.

    Recognised names are ``'basic'`` (always three layers), ``'n_layers'``,
    ``'no_norm'`` / ``'no_norm_4'`` (both build the same network) and
    ``'no_patchgan'``. When ``gpu_ids`` is non-empty the network is moved to
    ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel``. ``weights_init``
    is applied before returning.

    Raises:
        NotImplementedError: if ``which_model_netD`` is not recognised.
    """
    on_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if on_gpu:
        assert(torch.cuda.is_available())

    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD in ('no_norm', 'no_norm_4'):
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        netD = FCDiscriminator(input_nc, ndf, n_layers_D,
                               use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)

    if on_gpu:
        netD.cuda(device=gpu_ids[0])
        netD = torch.nn.DataParallel(netD, gpu_ids)
    netD.apply(weights_init)
    return netD
def print_network(net):
    """Print the module's repr followed by its total parameter count."""
    total = sum(tensor.numel() for tensor in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    """GAN objective that builds its own all-real / all-fake target tensor.

    Args:
        use_lsgan: use ``nn.MSELoss`` when true, ``nn.BCELoss`` otherwise.
        target_real_label: fill value for "real" targets.
        target_fake_label: fill value for "fake" targets.
        tensor: tensor constructor used to allocate the targets.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached targets, rebuilt whenever the prediction shape changes.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target with the same size as ``input``.

        The target is cached per label kind and reused while the size of
        ``input`` stays the same.
        """
        target_tensor = None
        if target_is_real:
            # Compare full sizes, not element counts: a cached target with the
            # same numel but a different shape must not be reused.
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.size() != input.size()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.size() != input.size()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        """Compute the loss of ``input`` against the real or fake target."""
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
class DiscLossWGANGP():
    """Gradient-penalty term for a WGAN-GP style critic."""

    def __init__(self):
        # Weight applied to the gradient penalty.
        self.LAMBDA = 10

    def name(self):
        """Return the display name of this loss."""
        return 'DiscLossWGAN-GP'

    def initialize(self, opt, tensor):
        """Reset the penalty weight; ``opt`` and ``tensor`` are not used."""
        # DiscLossLS.initialize(self, opt, tensor)
        self.LAMBDA = 10

    # def get_g_loss(self, net, realA, fakeB):
    #     # First, G(A) should fake the discriminator
    #     self.D_fake = net.forward(fakeB)
    #     return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """Return ``LAMBDA * mean((||grad||_2 - 1) ** 2)`` at a random blend.

        A single mixing coefficient is drawn for the whole batch, the critic
        is evaluated on the blend of ``real_data`` and ``fake_data``, and the
        2-norm of the input gradient is taken over dim 1.

        Args:
            netD: critic; its ``forward`` is called on the blended batch.
            real_data: real batch.
            fake_data: generated batch of the same size.

        Returns:
            Scalar penalty that stays attached to the graph (``create_graph``).
        """
        # Follow the data's device instead of forcing CUDA so the penalty can
        # also be computed on CPU or on a non-default GPU.
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.to(real_data.device)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        # Fresh leaf so gradients are taken w.r.t. the blend itself.
        interpolates = Variable(interpolates, requires_grad=True)

        disc_interpolates = netD.forward(interpolates)

        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones_like(disc_interpolates),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """ResNet-style generator: 7x7 stem, two stride-2 downsamplings,
    ``n_blocks`` residual blocks, two transposed-conv upsamplings and a
    7x7 output convolution followed by ``Tanh``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        n_downsampling = 2

        # Stem.
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        # Encoder: channel count doubles at every stride-2 convolution.
        channels = ngf
        for _ in range(n_downsampling):
            layers.extend([nn.Conv2d(channels, channels * 2, kernel_size=3,
                                     stride=2, padding=1),
                           norm_layer(channels * 2),
                           nn.ReLU(True)])
            channels = channels * 2

        # Residual trunk at the bottleneck width.
        layers.extend(ResnetBlock(channels, padding_type=padding_type,
                                  norm_layer=norm_layer, use_dropout=use_dropout)
                      for _ in range(n_blocks))

        # Decoder: channel count halves at every transposed convolution.
        for _ in range(n_downsampling):
            halved = int(channels / 2)
            layers.extend([nn.ConvTranspose2d(channels, halved,
                                              kernel_size=3, stride=2,
                                              padding=1, output_padding=1),
                           norm_layer(halved),
                           nn.ReLU(True)])
            channels = halved

        # Output head.
        layers.append(nn.ReflectionPad2d(3))
        layers.append(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the model, using ``data_parallel`` for CUDA float input when
        ``gpu_ids`` is non-empty."""
        multi_gpu = self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor)
        if not multi_gpu:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two 3x3 convolutions."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        """Assemble pad/conv/norm/ReLU [, dropout], pad/conv/norm.

        ``padding_type`` is ``'reflect'``, ``'replicate'`` or ``'zero'``;
        anything else raises ``NotImplementedError``.
        """

        def padded_conv():
            # Explicit padding layer (if any) followed by the 3x3 convolution.
            if padding_type == 'reflect':
                prefix, conv_pad = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                prefix, conv_pad = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                prefix, conv_pad = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return prefix + [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad)]

        stages = padded_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            stages.append(nn.Dropout(0.5))
        stages = stages + padded_conv() + [norm_layer(dim)]

        return nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the residual branch."""
        return x + self.conv_block(x)
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    """U-Net generator assembled inside-out from ``UnetSkipConnectionBlock``.

    ``num_downs`` is the number of stride-2 stages; ``num_downs - 5`` extra
    ``ngf * 8`` blocks are wrapped around the innermost one. When ``skip`` is
    ``True`` the network is wrapped in ``SkipModule``.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt

        # Only equal input/output channel counts are supported.
        assert(input_nc == output_nc)

        # Start at the bottleneck and wrap outwards.
        core = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt)
        for _ in range(num_downs - 5):
            core = UnetSkipConnectionBlock(ngf * 8, ngf * 8, core, norm_layer=norm_layer, use_dropout=use_dropout, opt=opt)
        core = UnetSkipConnectionBlock(ngf * 4, ngf * 8, core, norm_layer=norm_layer, opt=opt)
        core = UnetSkipConnectionBlock(ngf * 2, ngf * 4, core, norm_layer=norm_layer, opt=opt)
        core = UnetSkipConnectionBlock(ngf, ngf * 2, core, norm_layer=norm_layer, opt=opt)
        core = UnetSkipConnectionBlock(output_nc, ngf, core, outermost=True, norm_layer=norm_layer, opt=opt)

        self.model = SkipModule(core, opt) if skip == True else core

    def forward(self, input):
        """Run the model, using ``data_parallel`` for CUDA float input when
        ``gpu_ids`` is non-empty."""
        multi_gpu = self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor)
        if not multi_gpu:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
class SkipModule(nn.Module):
    """Wrap a module and add ``opt.skip`` times the input to its output."""

    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.submodule = submodule
        self.opt = opt

    def forward(self, x):
        """Return ``(opt.skip * x + latent, latent)`` where ``latent`` is the
        wrapped module's output."""
        latent = self.submodule(x)
        blended = self.opt.skip*x + latent
        return blended, latent
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run ``submodule``, upsample.

    Every block except the outermost one concatenates its input to its
    output along dim 1. Normalisation layers are only inserted when
    ``opt.use_norm`` is non-zero; the outermost block never uses them and
    ends in ``Tanh``.
    """

    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        without_norm = opt.use_norm == 0

        # Only a pure innermost block sees un-concatenated features; every
        # other block receives the submodule's output stacked with its input.
        up_in = inner_nc if (innermost and not outermost) else inner_nc * 2
        upconv = nn.ConvTranspose2d(up_in, outer_nc,
                                    kernel_size=4, stride=2,
                                    padding=1)

        if outermost:
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            layers = [downrelu, downconv, uprelu, upconv]
            if not without_norm:
                layers.append(upnorm)
        else:
            encoder = [downrelu, downconv]
            decoder = [uprelu, upconv]
            if not without_norm:
                encoder.append(downnorm)
                decoder.append(upnorm)
            layers = encoder + [submodule] + decoder
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the block output, concatenated with ``x`` unless outermost."""
        if not self.outermost:
            return torch.cat([self.model(x), x], 1)
        return self.model(x)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator with a normalisation layer after every
    convolution except the first and the last.

    The stack is one stride-2 convolution, ``n_layers - 1`` further stride-2
    stages, one stride-1 stage and a 1-channel stride-1 output convolution.
    Channel width doubles per stage and is capped at ``ndf * 8``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw-1)/2))

        # Width after the first conv and after each following stride-2 stage.
        widths = [ndf * min(2**k, 8) for k in range(n_layers)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out,
                          kernel_size=kw, stride=2, padding=padw),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, True),
            ])

        tail_in = widths[-1] if widths else ndf
        tail_out = ndf * min(2**n_layers, 8)
        layers.extend([
            nn.Conv2d(tail_in, tail_out,
                      kernel_size=kw, stride=1, padding=padw),
            norm_layer(tail_out),
            nn.LeakyReLU(0.2, True),
        ])

        layers.append(nn.Conv2d(tail_out, 1, kernel_size=kw, stride=1, padding=padw))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch score map for ``input``."""
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)
class NoNormDiscriminator(nn.Module):
    """PatchGAN discriminator without any normalisation layers.

    The stack is one stride-2 convolution, ``n_layers - 1`` further stride-2
    stages, one stride-1 stage and a 1-channel stride-1 output convolution.
    Channel width doubles per stage and is capped at ``ndf * 8``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw-1)/2))

        # Width after the first conv and after each following stride-2 stage.
        widths = [ndf * min(2**k, 8) for k in range(n_layers)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ])

        tail_in = widths[-1] if widths else ndf
        tail_out = ndf * min(2**n_layers, 8)
        layers.extend([
            nn.Conv2d(tail_in, tail_out,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True),
        ])

        layers.append(nn.Conv2d(tail_out, 1, kernel_size=kw, stride=1, padding=padw))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch score map for ``input``."""
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)
class FCDiscriminator(nn.Module):
    """Norm-free convolutional discriminator followed by one linear layer.

    The convolutional score map is flattened per sample and passed through
    ``nn.Linear``; the linear layer expects 7*7 features when ``patch`` is
    true and 13*13 otherwise, so the input size must produce exactly that
    many score-map elements.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False):
        super(FCDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_sigmoid = use_sigmoid

        kw = 4
        padw = int(np.ceil((kw-1)/2))

        # Width after the first conv and after each following stride-2 stage.
        widths = [ndf * min(2**k, 8) for k in range(n_layers)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ])

        tail_in = widths[-1] if widths else ndf
        tail_out = ndf * min(2**n_layers, 8)
        layers.extend([
            nn.Conv2d(tail_in, tail_out,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True),
        ])

        layers.append(nn.Conv2d(tail_out, 1, kernel_size=kw, stride=1, padding=padw))

        flat_features = 7*7 if patch else 13*13
        self.linear = nn.Linear(flat_features, 1)
        if use_sigmoid:
            self.sigmoid = nn.Sigmoid()

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return one score per sample, shape ``(batch, 1)``."""
        batchsize = input.size()[0]
        scores = self.model(input).view(batchsize, -1)
        # print(scores.size())
        scores = self.linear(scores)
        if self.use_sigmoid:
            print("sigmoid")
            scores = self.sigmoid(scores)
        return scores
class Unet_resize_conv(nn.Module):
    """Five-level U-Net whose decoder upsamples bilinearly and then applies a
    3x3 convolution (the ``deconvN`` layers) instead of a transposed conv.

    Behaviour is driven by ``opt``: ``self_attention`` (4-channel input and
    attention-map gating), ``use_norm`` / ``syn_norm`` (normalisation after
    each activation), ``use_avgpool`` (pooling type), ``tanh``,
    ``times_residual``, ``linear_add``, ``latent_threshold``, ``latent_norm``,
    ``skip`` and ``linear`` (output post-processing).
    """

    def __init__(self, opt, skip):
        super(Unet_resize_conv, self).__init__()

        self.opt = opt
        self.skip = skip
        p = 1
        # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
        if opt.self_attention:
            # Image and attention map are concatenated, hence 4 input channels.
            self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
            # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
            # Pooling layers that bring the attention map to each encoder scale.
            self.downsample_1 = nn.MaxPool2d(2)
            self.downsample_2 = nn.MaxPool2d(2)
            self.downsample_3 = nn.MaxPool2d(2)
            self.downsample_4 = nn.MaxPool2d(2)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        # Encoder level 1 (32 channels).
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder level 2 (64 channels).
        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder level 3 (128 channels).
        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder level 4 (256 channels).
        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Bottleneck (512 channels).
        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)

        # Decoder level 6: deconv5 halves the channels of the upsampled
        # features; conv6_1 takes them concatenated with the level-4 skip.
        # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)

        # Decoder level 7.
        # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)

        # Decoder level 8.
        # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)

        # Decoder level 9; note conv9_2 has no normalisation layer.
        # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        # 1x1 projection to the 3-channel latent.
        self.conv10 = nn.Conv2d(32, 3, 1)
        if self.opt.tanh:
            self.tanh = nn.Tanh()

    def depth_to_space(self, input, block_size):
        """Rearrange channel blocks of size ``block_size**2`` into space.

        Not called by ``forward`` in this file (only referenced from a
        commented-out line there).
        """
        block_size_sq = block_size*block_size
        # Work channels-last: (batch, height, width, depth).
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / block_size_sq)
        s_width = int(d_width * block_size)
        s_height = int(d_height * block_size)
        t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth)
        spl = t_1.split(block_size, 3)
        stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).resize(batch_size, s_height, s_width, s_depth)
        # Back to channels-first.
        output = output.permute(0, 3, 1, 2)
        return output

    def forward(self, input, gray):
        """Run the network on ``input`` with attention map ``gray``.

        Returns ``(output, latent)`` when ``self.skip`` is truthy, otherwise
        ``output`` alone.

        NOTE(review): if ``opt.use_norm`` is neither 0 nor 1, neither branch
        below runs and ``output`` is unbound when the padding is removed.
        """
        flag = 0
        # Inputs wider than 2200 (dim 3) are processed at half resolution and
        # the output is upsampled again at the end.
        if input.size()[3] > 2200:
            avg = nn.AvgPool2d(2)
            input = avg(input)
            gray = avg(gray)
            flag = 1
            # pass
        # Pad both tensors via pad_tensor; the pad amounts are rebound by the
        # second call, so the values used later come from `gray`.
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)
        gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray)
        if self.opt.self_attention:
            # Attention map at each encoder resolution.
            gray_2 = self.downsample_1(gray)
            gray_3 = self.downsample_2(gray_2)
            gray_4 = self.downsample_3(gray_3)
            gray_5 = self.downsample_4(gray_4)
        if self.opt.use_norm == 1:
            # ---- path with normalisation after every activation ----
            if self.opt.self_attention:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))))
                # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            else:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
            x = self.max_pool1(conv1)

            x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
            conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
            x = self.max_pool2(conv2)

            x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
            conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
            x = self.max_pool3(conv3)

            x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
            conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
            x = self.max_pool4(conv4)

            x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
            # Gate the bottleneck features with the coarsest attention map.
            x = x*gray_5 if self.opt.self_attention else x
            conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

            # Decoder: upsample, gate the skip features, concatenate, convolve.
            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4*gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
            conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3*gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
            conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2*gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
            conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1*gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent*gray

            # output = self.depth_to_space(conv10, 2)
            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        # Min-max rescale of the latent.
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    # Min-max rescale of the input, add, then map x -> 2x - 1.
                    input = (input - torch.min(input))/(torch.max(input) - torch.min(input))
                    output = latent + input*self.opt.skip
                    output = output*2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    output = latent + input*self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                # Scale by the largest absolute value.
                output = output/torch.max(torch.abs(output))

        elif self.opt.use_norm == 0:
            # ---- same pipeline without normalisation layers ----
            if self.opt.self_attention:
                x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))
            else:
                x = self.LReLU1_1(self.conv1_1(input))
            conv1 = self.LReLU1_2(self.conv1_2(x))
            x = self.max_pool1(conv1)

            x = self.LReLU2_1(self.conv2_1(x))
            conv2 = self.LReLU2_2(self.conv2_2(x))
            x = self.max_pool2(conv2)

            x = self.LReLU3_1(self.conv3_1(x))
            conv3 = self.LReLU3_2(self.conv3_2(x))
            x = self.max_pool3(conv3)

            x = self.LReLU4_1(self.conv4_1(x))
            conv4 = self.LReLU4_2(self.conv4_2(x))
            x = self.max_pool4(conv4)

            x = self.LReLU5_1(self.conv5_1(x))
            x = x*gray_5 if self.opt.self_attention else x
            conv5 = self.LReLU5_2(self.conv5_2(x))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4*gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.LReLU6_1(self.conv6_1(up6))
            conv6 = self.LReLU6_2(self.conv6_2(x))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3*gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.LReLU7_1(self.conv7_1(up7))
            conv7 = self.LReLU7_2(self.conv7_2(x))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2*gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.LReLU8_1(self.conv8_1(up8))
            conv8 = self.LReLU8_2(self.conv8_2(x))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1*gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.LReLU9_1(self.conv9_1(up9))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent*gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    input = (input - torch.min(input))/(torch.max(input) - torch.min(input))
                    output = latent + input*self.opt.skip
                    output = output*2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    output = latent + input*self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output/torch.max(torch.abs(output))

        # Remove the padding added at the top of this method.
        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom)
        gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom)
        if flag == 1:
            # Undo the initial 2x average pooling; `latent` is left as is.
            output = F.upsample(output, scale_factor=2, mode='bilinear')
            gray = F.upsample(gray, scale_factor=2, mode='bilinear')
        if self.skip:
            return output, latent
        else:
            return output
class DnCNN(nn.Module):
    """Residual convolutional stack: ``forward`` returns ``x + dncnn(x)``.

    The stack is conv+ReLU, then ``depth - 2`` conv+BatchNorm+ReLU stages,
    then a final conv back to ``image_channels``. ``opt`` and ``use_bnorm``
    are accepted but not used, and the ``kernel_size`` argument is always
    overridden with 3 (padding 1).
    """

    def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1

        def conv(c_in, c_out, with_bias):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=kernel_size, padding=padding, bias=with_bias)

        stack = [conv(image_channels, n_channels, True), nn.ReLU(inplace=True)]
        for _ in range(depth-2):
            stack += [
                conv(n_channels, n_channels, False),
                nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95),
                nn.ReLU(inplace=True),
            ]
        stack.append(conv(n_channels, image_channels, False))

        self.dncnn = nn.Sequential(*stack)
        self._initialize_weights()

    def forward(self, x):
        """Add the stack's output to its input."""
        return x + self.dncnn(x)

    def _initialize_weights(self):
        """Orthogonal conv weights with zero bias; BatchNorm weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                init.orthogonal_(layer.weight)
                print('init weight')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias, 0)
class Vgg16(nn.Module):
    """The thirteen 3x3 convolutions of VGG-16 with a selectable output tap."""

    def __init__(self):
        super(Vgg16, self).__init__()
        # (attribute name, in channels, out channels), in registration order.
        plan = (
            ('conv1_1', 3, 64), ('conv1_2', 64, 64),
            ('conv2_1', 64, 128), ('conv2_2', 128, 128),
            ('conv3_1', 128, 256), ('conv3_2', 256, 256), ('conv3_3', 256, 256),
            ('conv4_1', 256, 512), ('conv4_2', 512, 512), ('conv4_3', 512, 512),
            ('conv5_1', 512, 512), ('conv5_2', 512, 512), ('conv5_3', 512, 512),
        )
        for attr, c_in, c_out in plan:
            setattr(self, attr, nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))

    def forward(self, X, opt):
        """Return the feature map selected by ``opt.vgg_choose``.

        The pooling after block 3 is skipped when ``opt.vgg_choose`` is
        ``"no_maxpool"``; the pooling after block 4 additionally requires
        ``opt.vgg_maxpooling``. Any ``vgg_choose`` value that is not one of
        the named taps yields ``relu5_3``. The ReLUs run in place, so the
        tensors returned for ``"conv4_3"`` and ``"conv5_3"`` have already
        been rectified.
        """
        def act(t):
            return F.relu(t, inplace=True)

        def pool(t):
            return F.max_pool2d(t, kernel_size=2, stride=2)

        h = act(self.conv1_1(X))
        h = act(self.conv1_2(h))
        h = pool(h)

        h = act(self.conv2_1(h))
        h = act(self.conv2_2(h))
        h = pool(h)

        h = act(self.conv3_1(h))
        h = act(self.conv3_2(h))
        h = act(self.conv3_3(h))
        if opt.vgg_choose != "no_maxpool":
            h = pool(h)

        relu4_1 = act(self.conv4_1(h))
        relu4_2 = act(self.conv4_2(relu4_1))
        conv4_3 = self.conv4_3(relu4_2)
        relu4_3 = act(conv4_3)
        h = relu4_3

        if opt.vgg_choose != "no_maxpool" and opt.vgg_maxpooling:
            h = pool(h)

        relu5_1 = act(self.conv5_1(h))
        relu5_2 = act(self.conv5_2(relu5_1))
        conv5_3 = self.conv5_3(relu5_2)
        relu5_3 = act(conv5_3)

        taps = (
            ("conv4_3", conv4_3),
            ("relu4_2", relu4_2),
            ("relu4_1", relu4_1),
            ("relu4_3", relu4_3),
            ("conv5_3", conv5_3),
            ("relu5_1", relu5_1),
            ("relu5_2", relu5_2),
        )
        for tap_name, feature in taps:
            if opt.vgg_choose == tap_name:
                return feature
        # Every remaining choice (including "relu5_3" and "maxpool").
        return relu5_3
def vgg_preprocess(batch, opt):
    """Convert a 3-channel batch in [-1, 1] to the layout expected by the VGG nets.

    The three channels are reversed (RGB -> BGR), values are rescaled from
    [-1, 1] to [0, 255] and, when ``opt.vgg_mean`` is truthy, the per-channel
    means (103.939, 116.779, 123.680) are subtracted in BGR order.

    Args:
        batch: tensor with 3 channels along dim 1.
        opt: options object; only ``opt.vgg_mean`` is read.

    Returns:
        A new tensor with the same shape as ``batch``.
    """
    (r, g, b) = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5  # [-1, 1] -> [0, 255]
    if opt.vgg_mean:
        # Build the mean on the batch's own device/dtype and let broadcasting
        # expand it, instead of allocating a full-size legacy tensor that may
        # live on a different device than the input.
        mean = batch.new_tensor([103.939, 116.779, 123.680]).view(1, 3, 1, 1)
        batch = batch - mean
    return batch
class PerceptualLoss(nn.Module):
    """Mean squared distance between the VGG features of two image batches."""

    def __init__(self, opt):
        super(PerceptualLoss, self).__init__()
        self.opt = opt
        # Non-affine instance norm over 512 channels, applied to the features
        # unless opt.no_vgg_instance is set.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss(self, vgg, img, target):
        """Return the feature-space MSE between ``img`` and ``target``.

        Both batches go through ``vgg_preprocess`` and then through ``vgg``
        (called as ``vgg(x, self.opt)``).  When ``self.opt.no_vgg_instance``
        is falsy the features are instance-normalised before comparison.
        """
        prepared = [vgg_preprocess(t, self.opt) for t in (img, target)]
        img_fea, target_fea = [vgg(t, self.opt) for t in prepared]
        if not self.opt.no_vgg_instance:
            img_fea = self.instancenorm(img_fea)
            target_fea = self.instancenorm(target_fea)
        return torch.mean((img_fea - target_fea) ** 2)
def load_vgg16(model_dir, gpu_ids):
    """Build a ``Vgg16``, load its weights and wrap it in ``DataParallel``.

    Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py

    The weights are read from ``<model_dir>/vgg16.weight``; this function
    neither downloads nor converts them, so the file must already exist.

    Args:
        model_dir: directory holding ``vgg16.weight``; created (including
            missing parents) when absent.
        gpu_ids: sequence of CUDA device ids; the model is placed on
            ``gpu_ids[0]`` and replicated over all of them.

    Returns:
        ``torch.nn.DataParallel`` wrapping the loaded ``Vgg16``.
    """
    # makedirs(exist_ok=True) copes with nested paths and with another
    # process creating the directory concurrently.
    os.makedirs(model_dir, exist_ok=True)
    vgg = Vgg16()
    vgg.cuda(device=gpu_ids[0])
    weight_path = os.path.join(model_dir, 'vgg16.weight')
    # Deserialize onto the CPU so a checkpoint saved from a device that does
    # not exist here still loads; load_state_dict copies the values onto the
    # parameters' device.
    vgg.load_state_dict(torch.load(weight_path, map_location='cpu'))
    vgg = torch.nn.DataParallel(vgg, gpu_ids)
    return vgg
class FCN32s(nn.Module):
    """Fully convolutional network: five conv stages, two 4096-wide
    convolutional "fc" layers, a per-class scoring layer and one stride-32
    transposed convolution.

    Sub-module attribute names (``conv1_1`` ... ``upscore``) and their
    registration order are significant: they determine the state-dict keys
    used by ``load_fcn``.
    """

    # (output channels, number of conv+ReLU pairs) for each conv stage; every
    # stage ends with a 2x2, stride-2, ceil-mode max-pool.
    _STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    def __init__(self, n_class=21):
        super(FCN32s, self).__init__()
        channels = 3
        for stage, (width, depth) in enumerate(self._STAGES, 1):
            for idx in range(1, depth + 1):
                # Only the very first convolution uses the 100-pixel padding.
                is_first = stage == 1 and idx == 1
                conv = nn.Conv2d(channels, width, 3,
                                 padding=100 if is_first else 1)
                setattr(self, 'conv%d_%d' % (stage, idx), conv)
                setattr(self, 'relu%d_%d' % (stage, idx),
                        nn.ReLU(inplace=True))
                channels = width
            # Stage `stage` leaves the resolution at 1 / 2**stage.
            setattr(self, 'pool%d' % stage,
                    nn.MaxPool2d(2, stride=2, ceil_mode=True))

        # fc6
        self.fc6 = nn.Conv2d(channels, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        # fc7
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        # per-class scores and their 32x upsampling
        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32,
                                          bias=False)

    def _initialize_weights(self):
        """Zero every Conv2d and fill every ConvTranspose2d with the kernel
        returned by ``get_upsampling_weight`` (defined outside this class).
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.zero_()
                if module.bias is not None:
                    module.bias.data.zero_()
            if isinstance(module, nn.ConvTranspose2d):
                kernel = module.kernel_size
                assert kernel[0] == kernel[1]
                module.weight.data.copy_(
                    get_upsampling_weight(
                        module.in_channels, module.out_channels, kernel[0]))

    def forward(self, x):
        """Return the upsampled score map cropped to the spatial size of ``x``.

        The crop starts at a fixed offset of 19 in both spatial dimensions
        (presumably compensating for the 100-pixel padding of ``conv1_1`` --
        TODO confirm against the reference FCN implementation).
        """
        out = x
        for stage, (_, depth) in enumerate(self._STAGES, 1):
            for idx in range(1, depth + 1):
                conv = getattr(self, 'conv%d_%d' % (stage, idx))
                act = getattr(self, 'relu%d_%d' % (stage, idx))
                out = act(conv(out))
            out = getattr(self, 'pool%d' % stage)(out)

        out = self.drop6(self.relu6(self.fc6(out)))
        out = self.drop7(self.relu7(self.fc7(out)))
        out = self.upscore(self.score_fr(out))

        rows, cols = x.size()[2], x.size()[3]
        return out[:, :, 19:19 + rows, 19:19 + cols].contiguous()
def load_fcn(model_dir):
    """Create an ``FCN32s``, load ``<model_dir>/fcn32s_from_caffe.pth`` into
    it and return the network moved to the default CUDA device.
    """
    net = FCN32s()
    checkpoint_path = os.path.join(model_dir, 'fcn32s_from_caffe.pth')
    state = torch.load(checkpoint_path)
    net.load_state_dict(state)
    # Module.cuda() moves the parameters in place and returns the module.
    return net.cuda()
class SemanticLoss(nn.Module):
    """Mean squared distance between instance-normalised FCN outputs of two
    image batches.
    """

    def __init__(self, opt):
        super(SemanticLoss, self).__init__()
        self.opt = opt
        # Non-affine instance norm over 21 channels (presumably the FCN's
        # default number of classes -- confirm against the network passed in).
        self.instancenorm = nn.InstanceNorm2d(21, affine=False)

    def compute_fcn_loss(self, fcn, img, target):
        """Return the MSE between the normalised ``fcn`` outputs for ``img``
        and ``target``.

        Both batches are first passed through ``vgg_preprocess`` with
        ``self.opt`` and then through ``fcn`` (called with a single argument).
        """
        prepared = [vgg_preprocess(t, self.opt) for t in (img, target)]
        img_fea, target_fea = [fcn(t) for t in prepared]
        difference = self.instancenorm(img_fea) - self.instancenorm(target_fea)
        return torch.mean(difference ** 2)