|
import torch |
|
import os |
|
import math |
|
import torch.nn as nn |
|
from torch.nn import init |
|
import functools |
|
from torch.autograd import Variable |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
from lib.nn import SynchronizedBatchNorm2d as SynBN2d |
|
|
|
|
|
|
|
|
|
def pad_tensor(input, divide=16):
    """Reflection-pad a tensor so its height and width are multiples of ``divide``.

    The padding needed along each spatial axis is split as evenly as possible;
    for an odd amount the extra pixel goes to the right / bottom side.

    Args:
        input: tensor whose dims 2 and 3 are treated as height and width
            (presumably (N, C, H, W) — TODO confirm for all callers).
        divide: required spatial multiple. Defaults to 16, the value this
            function always used before it became a parameter.

    Returns:
        ``(padded, pad_left, pad_right, pad_top, pad_bottom)``. When no
        padding is needed ``padded`` is ``input`` itself and every pad is 0.
    """
    height_org, width_org = input.shape[2], input.shape[3]

    def _split(size):
        # Amount missing to reach the next multiple of ``divide``, split in two.
        remainder = size % divide
        if remainder == 0:
            return 0, 0
        total = divide - remainder
        before = total // 2
        return before, total - before

    pad_left, pad_right = _split(width_org)
    pad_top, pad_bottom = _split(height_org)

    if pad_left or pad_right or pad_top or pad_bottom:
        # ReflectionPad2d requires every pad to be smaller than the
        # corresponding input dimension.
        padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))
        input = padding(input)

    height, width = input.shape[2], input.shape[3]
    assert width % divide == 0, 'width cant divided by stride'
    assert height % divide == 0, 'height cant divided by stride'

    return input, pad_left, pad_right, pad_top, pad_bottom
|
|
|
def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
    """Crop away a border such as the one added by ``pad_tensor``.

    Removes ``pad_top`` / ``pad_bottom`` entries from dim 2 and ``pad_left`` /
    ``pad_right`` entries from dim 3; dims 0 and 1 are kept whole.
    """
    rows, cols = input.shape[2], input.shape[3]
    keep_rows = slice(pad_top, rows - pad_bottom)
    keep_cols = slice(pad_left, cols - pad_right)
    return input[:, :, keep_rows, keep_cols]
|
|
|
def weights_init(m):
    """Initializer intended for ``Module.apply``.

    * class name containing ``'Conv'``: weight ~ N(0, 0.02); bias untouched.
    * class name containing ``'BatchNorm2d'``: weight ~ N(1, 0.02), bias = 0.
    * anything else: left unchanged.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
|
|
|
|
|
def get_norm_layer(norm_type='instance'):
    """Return a factory for a 2-D normalization layer.

    Args:
        norm_type: ``'batch'`` (affine ``nn.BatchNorm2d``), ``'instance'``
            (non-affine ``nn.InstanceNorm2d``) or ``'synBN'`` (affine
            ``SynchronizedBatchNorm2d`` from ``lib.nn``).

    Returns:
        A ``functools.partial`` that takes the number of channels.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'synBN':
        return functools.partial(SynBN2d, affine=True)
    # Format with ``norm_type``: the previous message referenced an undefined
    # name ``norm`` and therefore raised NameError instead of this error.
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
|
|
|
|
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=None, skip=False, opt=None):
    """Build the generator named ``which_model_netG`` and initialise its weights.

    Args:
        input_nc: input image channels (ResNet / pix2pix U-Net variants).
        output_nc: output image channels (ResNet / pix2pix U-Net variants).
        ngf: base number of generator filters.
        which_model_netG: one of ``'resnet_9blocks'``, ``'resnet_6blocks'``,
            ``'unet_128'``, ``'unet_256'``, ``'unet_512'``, ``'sid_unet'``,
            ``'sid_unet_shuffle'``, ``'sid_unet_resize'``, ``'DnCNN'``.
        norm: normalization name passed to ``get_norm_layer``.
        use_dropout: enable dropout in the generators that support it.
        gpu_ids: CUDA device ids. When non-empty the network is moved to
            ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel``;
            ``None`` or ``[]`` keeps it on the CPU.
        skip: forwarded to the U-Net style generators.
        opt: option namespace forwarded to the generators that read it.

    Returns:
        The generator (DataParallel-wrapped when ``gpu_ids`` is non-empty)
        after ``weights_init`` has been applied to every submodule.

    Raises:
        NotImplementedError: for an unknown ``which_model_netG``.
    """
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert torch.cuda.is_available()

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        # Forward skip/opt like the other U-Net depths: every
        # UnetSkipConnectionBlock reads ``opt.use_norm``, so opt=None crashed.
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        # NOTE(review): weights_init below re-draws DnCNN's conv and BatchNorm
        # weights, overriding DnCNN's own orthogonal init — confirm intent.
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    # Only move / wrap when devices were requested. The previous test
    # ``len(gpu_ids) >= 0`` was always true and raised IndexError for [].
    if use_gpu:
        netG.cuda(device=gpu_ids[0])
        netG = torch.nn.DataParallel(netG, gpu_ids)
    netG.apply(weights_init)
    return netG
|
|
|
|
|
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False):
    """Build the discriminator named ``which_model_netD`` and initialise its weights.

    ``'basic'`` is an NLayerDiscriminator with 3 layers, ``'n_layers'`` uses
    ``n_layers_D``; ``'no_norm'`` and ``'no_norm_4'`` both build the same
    NoNormDiscriminator; ``'no_patchgan'`` builds an FCDiscriminator.
    When ``gpu_ids`` is non-empty the network is moved to ``gpu_ids[0]`` and
    wrapped in ``torch.nn.DataParallel``. ``weights_init`` is applied last.

    Raises:
        NotImplementedError: for an unknown ``which_model_netD``.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    on_gpu = len(gpu_ids) > 0
    if on_gpu:
        assert torch.cuda.is_available()

    if which_model_netD in ('basic', 'n_layers'):
        depth = 3 if which_model_netD == 'basic' else n_layers_D
        netD = NLayerDiscriminator(input_nc, ndf, depth, norm_layer=norm_layer,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD in ('no_norm', 'no_norm_4'):
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D,
                                   use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        netD = FCDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid,
                               gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)

    if on_gpu:
        netD.cuda(device=gpu_ids[0])
        netD = torch.nn.DataParallel(netD, gpu_ids)
    netD.apply(weights_init)
    return netD
|
|
|
|
|
def print_network(net):
    """Print ``net`` and then the total number of elements across its parameters."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GANLoss(nn.Module):
    """Adversarial loss of a prediction against a constant real/fake target.

    Uses ``nn.MSELoss`` (LSGAN) when ``use_lsgan`` is true, otherwise
    ``nn.BCELoss`` (which expects probabilities in [0, 1]). Target tensors are
    cached and rebuilt whenever the prediction's shape, dtype or device
    changes.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        # Kept for backward compatibility; targets are now created with the
        # prediction's own dtype/device so they always match it.
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    @staticmethod
    def _matches(cached, input):
        """True when ``cached`` can serve as the target for ``input``."""
        return (cached is not None
                and cached.size() == input.size()
                and cached.dtype == input.dtype
                and cached.device == input.device)

    def get_target_tensor(self, input, target_is_real):
        """Return a constant tensor shaped like ``input`` filled with the real or fake label.

        The previous cache check compared only ``numel``, so a prediction with
        the same element count but a different shape (or device) reused a
        mismatched target.
        """
        if target_is_real:
            if not self._matches(self.real_label_var, input):
                self.real_label_var = torch.full(input.size(), self.real_label,
                                                 dtype=input.dtype, device=input.device)
            return self.real_label_var
        if not self._matches(self.fake_label_var, input):
            self.fake_label_var = torch.full(input.size(), self.fake_label,
                                             dtype=input.dtype, device=input.device)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Loss of ``input`` against the real (``True``) or fake (``False``) target.

        Defined as ``forward`` (instead of overriding ``__call__``) so module
        hooks run; ``criterion(pred, True)`` works exactly as before.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
|
|
|
|
|
|
|
class DiscLossWGANGP():
    """WGAN-GP gradient penalty (Gulrajani et al., 2017) weighted by ``LAMBDA``."""

    def __init__(self):
        self.LAMBDA = 10

    def name(self):
        return 'DiscLossWGAN-GP'

    def initialize(self, opt, tensor):
        # ``opt`` and ``tensor`` are accepted for interface compatibility only.
        self.LAMBDA = 10

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """Return ``LAMBDA * mean_i (||grad D(x_hat_i)||_2 - 1)^2``.

        ``x_hat_i`` interpolates sample i of ``real_data`` and ``fake_data``
        with its own uniform coefficient. The gradient norm is taken over all
        non-batch dimensions of each sample. Works on whatever device the
        data lives on (previously hard-coded to CUDA).

        Args:
            netD: discriminator; called as ``netD.forward(x_hat)``.
            real_data: real batch; dim 0 is the batch dimension.
            fake_data: generated batch, same shape as ``real_data``.
        """
        batch_size = real_data.size(0)
        # One coefficient per sample, broadcast over the remaining dims.
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)

        interpolates = alpha * real_data + (1 - alpha) * fake_data
        # Fresh leaf: gradients are taken w.r.t. the interpolates only, not
        # propagated back into real/fake (matches the old Variable wrapper).
        interpolates = interpolates.detach().requires_grad_(True)

        disc_interpolates = netD.forward(interpolates)

        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones_like(disc_interpolates),
                                        create_graph=True, retain_graph=True)[0]

        # Per-sample L2 norm over every non-batch dimension.
        gradients = gradients.reshape(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty
|
|
|
|
|
|
|
|
|
|
|
class ResnetGenerator(nn.Module):
    """Encoder / residual-blocks / decoder generator with a Tanh output.

    Layout of ``self.model``: reflection-padded 7x7 conv stem, two stride-2
    3x3 convs (channels double each time), ``n_blocks`` ``ResnetBlock``s,
    two stride-2 transposed convs (channels halve each time), and a
    reflection-padded 7x7 conv to ``output_nc`` channels followed by Tanh.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        n_downsampling = 2
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        # Encoder.
        for level in range(n_downsampling):
            c_in = ngf * 2 ** level
            layers.extend([nn.Conv2d(c_in, c_in * 2, kernel_size=3, stride=2, padding=1),
                           norm_layer(c_in * 2),
                           nn.ReLU(True)])

        # Residual bottleneck at the widest channel count.
        width = ngf * 2 ** n_downsampling
        layers.extend(ResnetBlock(width, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout)
                      for _ in range(n_blocks))

        # Decoder, walking the levels back down.
        for level in range(n_downsampling, 0, -1):
            c_in = ngf * 2 ** level
            c_out = int(c_in / 2)
            layers.extend([nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                              padding=1, output_padding=1),
                           norm_layer(c_out),
                           nn.ReLU(True)])

        layers.extend([nn.ReflectionPad2d(3),
                       nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                       nn.Tanh()])

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        use_parallel = bool(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor)
        if not use_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
|
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is pad -> 3x3 conv -> norm -> ReLU -> [Dropout(0.5)] -> pad ->
    3x3 conv -> norm, where "pad" is an explicit reflection / replication
    layer or, for ``'zero'``, the conv's own ``padding=1``.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    @staticmethod
    def _padding_for(padding_type):
        """Return ``(explicit padding layers, conv padding)`` for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        pad_layers, conv_pad = self._padding_for(padding_type)
        layers = pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
                               norm_layer(dim),
                               nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        pad_layers, conv_pad = self._padding_for(padding_type)
        layers += pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
                                norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnetGenerator(nn.Module):
    """U-Net built from nested ``UnetSkipConnectionBlock``s, innermost first.

    ``num_downs`` is the number of stride-2 downsamplings; levels beyond the
    fixed five use ``ngf * 8`` channels and (optionally) dropout. When
    ``skip == True`` the network is wrapped in ``SkipModule``.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt

        # The outermost block's first conv takes ``output_nc`` channels as its
        # input, so the input must have that many channels too.
        assert input_nc == output_nc

        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, block, norm_layer=norm_layer,
                                            use_dropout=use_dropout, opt=opt)
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, block,
                                            norm_layer=norm_layer, opt=opt)
        block = UnetSkipConnectionBlock(output_nc, ngf, block, outermost=True,
                                        norm_layer=norm_layer, opt=opt)

        # Equality with True (not plain truthiness) is intentional: it keeps the
        # original check, e.g. a float skip weight does NOT enable the wrapper.
        self.model = SkipModule(block, opt) if skip == True else block  # noqa: E712

    def forward(self, input):
        use_parallel = bool(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor)
        if not use_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
|
|
|
class SkipModule(nn.Module):
    """Wrap ``submodule`` with a shortcut scaled by ``opt.skip``.

    ``forward(x)`` returns ``(opt.skip * x + submodule(x), submodule(x))``.
    """

    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.submodule = submodule
        self.opt = opt

    def forward(self, x):
        residual = self.submodule(x)
        shortcut = self.opt.skip * x
        return shortcut + residual, residual
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnetSkipConnectionBlock(nn.Module): |
|
def __init__(self, outer_nc, inner_nc, |
|
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None): |
|
super(UnetSkipConnectionBlock, self).__init__() |
|
self.outermost = outermost |
|
|
|
downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, |
|
stride=2, padding=1) |
|
downrelu = nn.LeakyReLU(0.2, True) |
|
downnorm = norm_layer(inner_nc) |
|
uprelu = nn.ReLU(True) |
|
upnorm = norm_layer(outer_nc) |
|
|
|
if opt.use_norm == 0: |
|
if outermost: |
|
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, |
|
kernel_size=4, stride=2, |
|
padding=1) |
|
down = [downconv] |
|
up = [uprelu, upconv, nn.Tanh()] |
|
model = down + [submodule] + up |
|
elif innermost: |
|
upconv = nn.ConvTranspose2d(inner_nc, outer_nc, |
|
kernel_size=4, stride=2, |
|
padding=1) |
|
down = [downrelu, downconv] |
|
up = [uprelu, upconv] |
|
model = down + up |
|
else: |
|
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, |
|
kernel_size=4, stride=2, |
|
padding=1) |
|
down = [downrelu, downconv] |
|
up = [uprelu, upconv] |
|
|
|
if use_dropout: |
|
model = down + [submodule] + up + [nn.Dropout(0.5)] |
|
else: |
|
model = down + [submodule] + up |
|
else: |
|
if outermost: |
|
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, |
|
kernel_size=4, stride=2, |
|
padding=1) |
|
down = [downconv] |
|
up = [uprelu, upconv, nn.Tanh()] |
|
model = down + [submodule] + up |
|
elif innermost: |
|
upconv = nn.ConvTranspose2d(inner_nc, outer_nc, |
|
kernel_size=4, stride=2, |
|
padding=1) |
|
down = [downrelu, downconv] |
|
up = [uprelu, upconv, upnorm] |
|
model = down + up |
|
else: |
|
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, |
|
kernel_size=4, stride=2, |
|
padding=1) |
|
down = [downrelu, downconv, downnorm] |
|
up = [uprelu, upconv, upnorm] |
|
|
|
if use_dropout: |
|
model = down + [submodule] + up + [nn.Dropout(0.5)] |
|
else: |
|
model = down + [submodule] + up |
|
|
|
self.model = nn.Sequential(*model) |
|
|
|
def forward(self, x): |
|
if self.outermost: |
|
return self.model(x) |
|
else: |
|
return torch.cat([self.model(x), x], 1) |
|
|
|
|
|
|
|
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convs with LeakyReLU(0.2).

    A stride-2 stem conv (no norm), ``n_layers - 1`` stride-2 conv/norm/act
    stages, one stride-1 conv/norm/act stage, then a stride-1 conv to one
    channel; stage k has ``ndf * min(2**k, 8)`` channels. A Sigmoid is
    appended when ``use_sigmoid``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))

        def stage(c_in, c_out, stride):
            return [nn.Conv2d(c_in, c_out, kernel_size=kw, stride=stride, padding=padw),
                    norm_layer(c_out),
                    nn.LeakyReLU(0.2, True)]

        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]
        channels = ndf
        for level in range(1, n_layers):
            widened = ndf * min(2 ** level, 8)
            layers += stage(channels, widened, 2)
            channels = widened

        widened = ndf * min(2 ** n_layers, 8)
        layers += stage(channels, widened, 1)
        layers.append(nn.Conv2d(widened, 1, kernel_size=kw, stride=1, padding=padw))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
|
|
|
class NoNormDiscriminator(nn.Module):
    """PatchGAN discriminator with the NLayerDiscriminator layout but no normalization.

    Stride-2 stem conv, ``n_layers - 1`` stride-2 conv/LeakyReLU stages, one
    stride-1 conv/LeakyReLU stage, then a stride-1 conv to one channel
    (plus Sigmoid when ``use_sigmoid``).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))

        stack = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                 nn.LeakyReLU(0.2, True)]
        in_ch = ndf
        for n in range(1, n_layers):
            out_ch = ndf * min(2 ** n, 8)
            stack.extend([nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=2, padding=padw),
                          nn.LeakyReLU(0.2, True)])
            in_ch = out_ch

        out_ch = ndf * min(2 ** n_layers, 8)
        stack.extend([nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=1, padding=padw),
                      nn.LeakyReLU(0.2, True)])
        stack.append(nn.Conv2d(out_ch, 1, kernel_size=kw, stride=1, padding=padw))

        if use_sigmoid:
            stack.append(nn.Sigmoid())

        self.model = nn.Sequential(*stack)

    def forward(self, input):
        return self.model(input)
|
|
|
class FCDiscriminator(nn.Module):
    """Conv stack (NoNormDiscriminator layout) followed by a linear head.

    The one-channel conv output is flattened per sample and fed to
    ``nn.Linear(13*13, 1)`` (``nn.Linear(7*7, 1)`` when ``patch``), so the
    input size must make the conv stack emit exactly that many positions;
    e.g. with ``n_layers=3`` a 74x74 input gives 13x13 and 26x26 gives 7x7.
    Returns one score per sample, passed through a Sigmoid when
    ``use_sigmoid``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False):
        super(FCDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_sigmoid = use_sigmoid
        kw = 4
        padw = int(np.ceil((kw - 1) / 2))

        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                    nn.LeakyReLU(0.2, True)]
        in_ch = ndf
        for n in range(1, n_layers):
            out_ch = ndf * min(2 ** n, 8)
            sequence += [nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=2, padding=padw),
                         nn.LeakyReLU(0.2, True)]
            in_ch = out_ch

        out_ch = ndf * min(2 ** n_layers, 8)
        sequence += [nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=1, padding=padw),
                     nn.LeakyReLU(0.2, True)]
        sequence += [nn.Conv2d(out_ch, 1, kernel_size=kw, stride=1, padding=padw)]

        if patch:
            self.linear = nn.Linear(7 * 7, 1)
        else:
            self.linear = nn.Linear(13 * 13, 1)
        if use_sigmoid:
            self.sigmoid = nn.Sigmoid()

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        batchsize = input.size()[0]
        output = self.model(input).view(batchsize, -1)
        output = self.linear(output)
        # The former debug ``print("sigmoid")`` on every call was removed.
        if self.use_sigmoid:
            output = self.sigmoid(output)
        return output
|
|
|
|
|
class Unet_resize_conv(nn.Module):
    """U-Net generator whose decoder upsamples by bilinear resize + 3x3 conv.

    Encoder: five double-conv stages with 32, 64, 128, 256 and 512 channels.
    Each 3x3 conv is followed by LeakyReLU(0.2) and, when
    ``opt.use_norm == 1``, a BatchNorm2d (SynchronizedBatchNorm2d when
    ``opt.syn_norm``). Stages 1-4 end with a 2x pool (AvgPool2d when
    ``opt.use_avgpool == 1``, else MaxPool2d).

    Decoder: four stages of ``F.upsample(x2, bilinear)`` -> 3x3 ``deconvK``
    conv -> concat with the matching encoder feature -> double conv, then a
    1x1 ``conv10`` projecting to 3 channels.

    Options read from ``opt``: ``self_attention``, ``use_norm``,
    ``syn_norm``, ``use_avgpool``, ``tanh``, ``times_residual``,
    ``linear_add``, ``latent_threshold``, ``latent_norm``, ``skip``,
    ``linear``.

    With ``opt.self_attention`` the first conv takes 4 channels (the input
    concatenated with ``gray``; presumably a 3-channel image plus a 1-channel
    map — TODO confirm). ``gray`` is max-pooled to every encoder scale and
    multiplied into the bottleneck and into each skip connection.
    """

    def __init__(self, opt, skip):
        # opt: option namespace (see class docstring).
        # skip: when truthy, forward() adds ``opt.skip * input`` to the
        #   prediction and returns ``(output, latent)`` instead of ``output``.
        super(Unet_resize_conv, self).__init__()

        self.opt = opt
        self.skip = skip
        p = 1  # padding of every 3x3 conv; keeps the spatial size unchanged

        if opt.self_attention:
            # 4 input channels: input and gray are concatenated in forward().
            self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)

            # Bring ``gray`` down to the resolutions of encoder stages 2-5.
            self.downsample_1 = nn.MaxPool2d(2)
            self.downsample_2 = nn.MaxPool2d(2)
            self.downsample_3 = nn.MaxPool2d(2)
            self.downsample_4 = nn.MaxPool2d(2)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        # Encoder stage 1 (32 channels). BatchNorm layers in this class exist
        # only when use_norm == 1.
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder stage 2 (64 channels).
        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder stage 3 (128 channels).
        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder stage 4 (256 channels).
        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Bottleneck stage 5 (512 channels, no pooling afterwards).
        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)

        # Decoder stage 6: deconv5 maps the upsampled bottleneck (512) to 256
        # channels; concatenated with conv4 (256) it gives conv6_1's 512 inputs.
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)

        # Decoder stage 7 (deconv6 output + conv3 -> 256 inputs).
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)

        # Decoder stage 8 (deconv7 output + conv2 -> 128 inputs).
        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)

        # Decoder stage 9 (deconv8 output + conv1 -> 64 inputs). There is no
        # bn9_2: the last conv of this stage is never normalized.
        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        # 1x1 projection to the 3 output channels.
        self.conv10 = nn.Conv2d(32, 3, 1)
        if self.opt.tanh:
            self.tanh = nn.Tanh()

    def depth_to_space(self, input, block_size):
        """Appears to move channel blocks into ``block_size`` x ``block_size``
        spatial blocks (depth-to-space / pixel-shuffle style upscaling).

        Not called by forward(). NOTE(review): relies on the non-inplace
        ``Tensor.resize``, which is deprecated in recent PyTorch releases.
        """
        block_size_sq = block_size*block_size
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / block_size_sq)
        s_width = int(d_width * block_size)
        s_height = int(d_height * block_size)
        t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth)
        spl = t_1.split(block_size, 3)
        stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).resize(batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output

    def forward(self, input, gray):
        """Run the generator on ``input`` guided by ``gray``.

        Args:
            input: image batch; presumably (N, 3, H, W) — TODO confirm.
            gray: guidance map; presumably (N, 1, H, W) with the same spatial
                size as ``input`` (it is concatenated with ``input`` and
                multiplied into feature maps).

        Returns:
            ``(output, latent)`` when ``self.skip`` is truthy, else ``output``.

        NOTE(review): only ``opt.use_norm`` values 0 and 1 are handled; any
        other value leaves ``output``/``latent`` unbound, so pad_tensor_back
        raises UnboundLocalError.
        """
        flag = 0
        # Very wide inputs (size(3) > 2200, presumably the width) are halved
        # first; ``flag`` records this so the output is upsampled back.
        if input.size()[3] > 2200:
            avg = nn.AvgPool2d(2)
            input = avg(input)
            gray = avg(gray)
            flag = 1

        # Pad so four 2x poolings divide evenly. The second call overwrites
        # the pad amounts; they agree only if input and gray share H and W.
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)
        gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray)
        if self.opt.self_attention:
            # gray at 1/2, 1/4, 1/8 and 1/16 resolution (encoder stages 2-5).
            gray_2 = self.downsample_1(gray)
            gray_3 = self.downsample_2(gray_2)
            gray_4 = self.downsample_3(gray_3)
            gray_5 = self.downsample_4(gray_4)
        if self.opt.use_norm == 1:
            # ---- Path with BatchNorm (applied after each LeakyReLU). ----
            if self.opt.self_attention:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))))
            else:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
            x = self.max_pool1(conv1)

            x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
            conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
            x = self.max_pool2(conv2)

            x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
            conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
            x = self.max_pool3(conv3)

            x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
            conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
            x = self.max_pool4(conv4)

            # Bottleneck, gated by the 1/16-resolution gray map.
            x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
            x = x*gray_5 if self.opt.self_attention else x
            conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

            # Decoder: resize x2, 3x3 conv, concat with the (gated) skip feature.
            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4*gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
            conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3*gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
            conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2*gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
            conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1*gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent*gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        # Min-max over the whole batch tensor, not per sample.
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    # Input is always min-max scaled to [0, 1] here; the sum is
                    # then mapped by x*2 - 1.
                    input = (input - torch.min(input))/(torch.max(input) - torch.min(input))
                    output = latent + input*self.opt.skip
                    output = output*2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    output = latent + input*self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                # Scale so the largest absolute value in the batch tensor is 1.
                output = output/torch.max(torch.abs(output))

        elif self.opt.use_norm == 0:
            # ---- Same computation without BatchNorm layers. ----
            if self.opt.self_attention:
                x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))
            else:
                x = self.LReLU1_1(self.conv1_1(input))
            conv1 = self.LReLU1_2(self.conv1_2(x))
            x = self.max_pool1(conv1)

            x = self.LReLU2_1(self.conv2_1(x))
            conv2 = self.LReLU2_2(self.conv2_2(x))
            x = self.max_pool2(conv2)

            x = self.LReLU3_1(self.conv3_1(x))
            conv3 = self.LReLU3_2(self.conv3_2(x))
            x = self.max_pool3(conv3)

            x = self.LReLU4_1(self.conv4_1(x))
            conv4 = self.LReLU4_2(self.conv4_2(x))
            x = self.max_pool4(conv4)

            x = self.LReLU5_1(self.conv5_1(x))
            x = x*gray_5 if self.opt.self_attention else x
            conv5 = self.LReLU5_2(self.conv5_2(x))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4*gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.LReLU6_1(self.conv6_1(up6))
            conv6 = self.LReLU6_2(self.conv6_2(x))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3*gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.LReLU7_1(self.conv7_1(up7))
            conv7 = self.LReLU7_2(self.conv7_2(x))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2*gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.LReLU8_1(self.conv8_1(up8))
            conv8 = self.LReLU8_2(self.conv8_2(x))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1*gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.LReLU9_1(self.conv9_1(up9))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent*gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    input = (input - torch.min(input))/(torch.max(input) - torch.min(input))
                    output = latent + input*self.opt.skip
                    output = output*2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    output = latent + input*self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output/torch.max(torch.abs(output))

        # Undo the divisibility padding (using gray's pad amounts).
        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom)
        gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom)
        if flag == 1:
            # Restore the pre-halving size. NOTE(review): ``latent`` is not
            # upsampled, so it is returned at half resolution in this case.
            output = F.upsample(output, scale_factor=2, mode='bilinear')
            gray = F.upsample(gray, scale_factor=2, mode='bilinear')
        if self.skip:
            return output, latent
        else:
            return output
|
|
|
class DnCNN(nn.Module):
    """DnCNN-style residual network: ``forward(x) = x + body(x)``.

    Body: conv + ReLU, then ``depth - 2`` x (conv [+ BatchNorm] + ReLU), then a
    conv back to ``image_channels``. Conv weights get orthogonal init, biases
    and BatchNorm shifts zero, BatchNorm scales one.

    Args:
        opt: accepted for interface compatibility; unused.
        depth: total number of conv layers.
        n_channels: width of the hidden layers.
        image_channels: channels of the input and output.
        use_bnorm: insert ``BatchNorm2d`` after each hidden conv (this flag was
            previously ignored and BatchNorm was always used).
        kernel_size: conv kernel size (previously always forced to 3); padding
            is ``kernel_size // 2``, so odd sizes keep the spatial size.
    """

    def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        padding = kernel_size // 2
        layers = [
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                      kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(inplace=True),
        ]
        for _ in range(depth - 2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                                    kernel_size=kernel_size, padding=padding, bias=False))
            if use_bnorm:
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        # The body starts with a (non in-place) conv, so ``x`` is unchanged here.
        return x + self.dncnn(x)

    def _initialize_weights(self):
        """Orthogonal conv weights, zero biases, unit/zero BatchNorm affine params.

        The per-conv debug ``print('init weight')`` was removed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
|
|
|
class Vgg16(nn.Module):
    """Convolutional trunk of VGG-16 (conv1_1 .. conv5_3) used as a feature extractor.

    Only the convolution layers are defined; there are no fully connected layers.
    ``forward`` returns one intermediate activation, selected by
    ``opt.vgg_choose``. Inputs are presumably produced by ``vgg_preprocess``
    (BGR, 0-255 range) — confirm against the checkpoint being loaded.
    """

    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X, opt):
        """Run the trunk and return the activation named by ``opt.vgg_choose``.

        Recognised names:
        "relu4_1", "relu4_2", "conv4_3", "relu4_3",
        "relu5_1", "relu5_2", "conv5_3", "relu5_3".
        ``conv*`` names return the pre-activation convolution output. Any
        other value, including "maxpool" and "no_maxpool", returns relu5_3.
        "no_maxpool" additionally skips the third and fourth max-pool. If it
        is not set, ``opt.vgg_maxpooling`` decides whether the fourth max-pool
        (before conv5_1) is applied.

        Args:
            X: input image batch with 3 channels.
            opt: options object providing ``vgg_choose`` and, when conv5
                features are computed, ``vgg_maxpooling``.

        Returns:
            The selected feature tensor.
        """
        choice = opt.vgg_choose
        keep_pooling = choice != "no_maxpool"

        h = F.relu(self.conv1_1(X), inplace=True)
        h = F.relu(self.conv1_2(h), inplace=True)
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h), inplace=True)
        h = F.relu(self.conv2_2(h), inplace=True)
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h), inplace=True)
        h = F.relu(self.conv3_2(h), inplace=True)
        h = F.relu(self.conv3_3(h), inplace=True)
        if keep_pooling:
            h = F.max_pool2d(h, kernel_size=2, stride=2)

        relu4_1 = F.relu(self.conv4_1(h), inplace=True)
        relu4_2 = F.relu(self.conv4_2(relu4_1), inplace=True)
        conv4_3 = self.conv4_3(relu4_2)
        # Out-of-place on purpose: an in-place ReLU would overwrite conv4_3,
        # so asking for "conv4_3" would silently yield post-activation values.
        relu4_3 = F.relu(conv4_3)

        stage4 = {
            "relu4_1": relu4_1,
            "relu4_2": relu4_2,
            "conv4_3": conv4_3,
            "relu4_3": relu4_3,
        }
        if choice in stage4:
            # Skip the conv5 stage entirely; its output would be discarded.
            return stage4[choice]

        h = relu4_3
        if keep_pooling and opt.vgg_maxpooling:
            h = F.max_pool2d(h, kernel_size=2, stride=2)

        relu5_1 = F.relu(self.conv5_1(h), inplace=True)
        relu5_2 = F.relu(self.conv5_2(relu5_1), inplace=True)
        conv5_3 = self.conv5_3(relu5_2)
        # Out-of-place for the same reason as relu4_3 above.
        relu5_3 = F.relu(conv5_3)

        stage5 = {
            "relu5_1": relu5_1,
            "relu5_2": relu5_2,
            "conv5_3": conv5_3,
        }
        # "relu5_3", "maxpool", "no_maxpool" and any other value fall back to relu5_3.
        return stage5.get(choice, relu5_3)
|
|
|
def vgg_preprocess(batch, opt):
    """Convert an RGB batch into the BGR, 0-255 layout used for VGG features.

    Channel order along dim 1 is swapped from (R, G, B) to (B, G, R). Values
    are then mapped with ``(x + 1) * 127.5``; this presumably assumes inputs
    in [-1, 1], yielding [0, 255].

    Args:
        batch: tensor with exactly 3 channels along dim 1, in R, G, B order.
        opt: options object; when ``opt.vgg_mean`` is truthy, the per-channel
            BGR means (103.939, 116.779, 123.680) are subtracted.

    Returns:
        A new tensor of the same shape, on the same device and dtype as ``batch``.
    """
    r, g, b = torch.chunk(batch, 3, dim=1)
    bgr = torch.cat((b, g, r), dim=1)
    bgr = (bgr + 1) * 255 * 0.5

    if opt.vgg_mean:
        # Build the mean on batch's device/dtype so CUDA inputs work; it
        # broadcasts over batch and spatial dims instead of filling a full-size tensor.
        mean = bgr.new_tensor([103.939, 116.779, 123.680]).view(1, 3, 1, 1)
        bgr = bgr - mean

    return bgr
|
|
|
class PerceptualLoss(nn.Module):
    """Mean-squared distance between VGG feature maps of two images.

    Both images are passed through ``vgg_preprocess`` and then through the
    supplied VGG network, using this loss's options object. Unless
    ``opt.no_vgg_instance`` is truthy, each feature map is instance-normalised
    (no affine parameters) before the squared difference is averaged.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # affine=False: pure normalisation with no learnable parameters.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss(self, vgg, img, target):
        """Return the scalar perceptual loss between ``img`` and ``target``.

        Args:
            vgg: callable invoked as ``vgg(preprocessed_batch, opt)``.
            img: generated image batch.
            target: reference image batch.
        """
        options = self.opt
        img_features, target_features = (
            vgg(vgg_preprocess(images, options), options) for images in (img, target)
        )

        if not options.no_vgg_instance:
            img_features = self.instancenorm(img_features)
            target_features = self.instancenorm(target_features)

        return torch.mean((img_features - target_features) ** 2)
|
|
|
def load_vgg16(model_dir, gpu_ids):
    """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py

    Build a ``Vgg16`` on ``gpu_ids[0]``, load ``<model_dir>/vgg16.weight``
    into it, and wrap it in ``DataParallel`` over ``gpu_ids``.

    ``model_dir`` is created if it is missing, but nothing is downloaded. The
    weight file must already exist, or ``torch.load`` raises
    ``FileNotFoundError``.

    Args:
        model_dir: directory holding ``vgg16.weight``.
        gpu_ids: non-empty sequence of CUDA device ids; the first hosts the
            master copy.

    Returns:
        ``torch.nn.DataParallel`` wrapping the loaded ``Vgg16``.
    """
    # makedirs handles nested paths and an already-existing directory.
    os.makedirs(model_dir, exist_ok=True)

    vgg = Vgg16()
    vgg.cuda(device=gpu_ids[0])
    # Load to CPU first. This avoids tensors landing on whatever device the
    # checkpoint was saved from; load_state_dict then copies into the CUDA
    # parameters. Note: torch.load unpickles, so only load trusted files.
    state_dict = torch.load(os.path.join(model_dir, 'vgg16.weight'), map_location='cpu')
    vgg.load_state_dict(state_dict)
    vgg = torch.nn.DataParallel(vgg, gpu_ids)
    return vgg
|
|
|
|
|
|
|
class FCN32s(nn.Module):
    """FCN-32s segmentation network on a VGG-16 style backbone.

    Five convolutional stages (``convS_I`` + ``reluS_I`` layers followed by
    ``poolS``) feed convolutional "fully connected" layers ``fc6``/``fc7``. A
    1x1 classifier ``score_fr`` produces ``n_class`` channels. A stride-32
    transposed convolution ``upscore`` upsamples the scores, and the result is
    cropped back to the input's spatial size.
    Attribute names follow the original layer naming, so state-dict keys stay stable.
    """

    # (stage number, input channels, output channels, number of convs)
    def __init__(self, n_class=21):
        super(FCN32s, self).__init__()

        stage_specs = (
            (1, 3, 64, 2),
            (2, 64, 128, 2),
            (3, 128, 256, 3),
            (4, 256, 512, 3),
            (5, 512, 512, 3),
        )
        for stage, in_ch, out_ch, n_convs in stage_specs:
            for idx in range(1, n_convs + 1):
                src_ch = in_ch if idx == 1 else out_ch
                # Only the very first conv gets the large padding=100; the
                # fixed crop offset in forward() assumes it.
                pad = 100 if (stage, idx) == (1, 1) else 1
                setattr(self, 'conv%d_%d' % (stage, idx), nn.Conv2d(src_ch, out_ch, 3, padding=pad))
                setattr(self, 'relu%d_%d' % (stage, idx), nn.ReLU(inplace=True))
            setattr(self, 'pool%d' % stage, nn.MaxPool2d(2, stride=2, ceil_mode=True))

        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, bias=False)

    def _initialize_weights(self):
        """Zero every Conv2d; fill ConvTranspose2d layers via ``get_upsampling_weight``.

        ``get_upsampling_weight`` is not defined in this class. It must be
        provided elsewhere in the module for this method to run.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.zero_()
                if module.bias is not None:
                    module.bias.data.zero_()
            if isinstance(module, nn.ConvTranspose2d):
                size = module.kernel_size
                assert size[0] == size[1]
                upsampling = get_upsampling_weight(module.in_channels, module.out_channels, size[0])
                module.weight.data.copy_(upsampling)

    def forward(self, x):
        """Return per-class score maps with the same height and width as ``x``."""
        out = x
        for stage, n_convs in ((1, 2), (2, 2), (3, 3), (4, 3), (5, 3)):
            for idx in range(1, n_convs + 1):
                conv = getattr(self, 'conv%d_%d' % (stage, idx))
                relu = getattr(self, 'relu%d_%d' % (stage, idx))
                out = relu(conv(out))
            out = getattr(self, 'pool%d' % stage)(out)

        out = self.drop6(self.relu6(self.fc6(out)))
        out = self.drop7(self.relu7(self.fc7(out)))
        out = self.upscore(self.score_fr(out))

        # Fixed 19-pixel offset crop back to the input's spatial size.
        rows, cols = x.size(2), x.size(3)
        return out[:, :, 19:19 + rows, 19:19 + cols].contiguous()
|
|
|
def load_fcn(model_dir):
    """Build an ``FCN32s``, load ``<model_dir>/fcn32s_from_caffe.pth`` and move it to CUDA.

    The checkpoint is loaded onto the CPU first (``map_location='cpu'``). This
    keeps loading independent of the device the file was saved from, before
    the model is moved to the current CUDA device. ``torch.load`` unpickles,
    so only load trusted files.

    Args:
        model_dir: directory containing ``fcn32s_from_caffe.pth``.

    Returns:
        The loaded ``FCN32s`` on the current CUDA device.
    """
    fcn = FCN32s()
    state_dict = torch.load(os.path.join(model_dir, 'fcn32s_from_caffe.pth'), map_location='cpu')
    fcn.load_state_dict(state_dict)
    fcn.cuda()
    return fcn
|
|
|
class SemanticLoss(nn.Module):
    """Mean-squared distance between instance-normalised FCN score maps.

    Both images go through ``vgg_preprocess`` (with this loss's options) and
    then the supplied segmentation network. The two outputs are
    instance-normalised without affine parameters, and their squared
    difference is averaged.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # 21 matches FCN32s's default n_class; affine=False means no learnable parameters.
        self.instancenorm = nn.InstanceNorm2d(21, affine=False)

    def compute_fcn_loss(self, fcn, img, target):
        """Return the scalar semantic loss between ``img`` and ``target``.

        Args:
            fcn: callable invoked as ``fcn(preprocessed_batch)``.
            img: generated image batch.
            target: reference image batch.
        """
        img_scores, target_scores = (
            fcn(vgg_preprocess(images, self.opt)) for images in (img, target)
        )
        normalize = self.instancenorm
        diff = normalize(img_scores) - normalize(target_scores)
        return torch.mean(diff ** 2)
|
|