import torch
from collections import OrderedDict
from torch.autograd import Variable
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
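

# PairModel keeps the CycleGAN-style scaffolding but trains only the forward
# generator, G_A, with paired supervision: fake_B = G_A(real_A, A_gray) is
# regressed onto the ground-truth real_B with an L1 loss; no discriminator
# is updated in this model.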
class PairModel(BaseModel):
    def name(self):
        return 'PairModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.opt = opt
        # Preallocated input buffers; set_input() resizes and copies each
        # incoming batch into them (legacy PyTorch <= 0.4 pattern).
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        self.input_img = self.Tensor(nb, opt.input_nc, size, size)
        self.input_A_gray = self.Tensor(nb, 1, size, size)

        if opt.vgg > 0:
            # Frozen VGG-16 serves as a fixed feature extractor for the
            # perceptual loss; its weights are never updated.
            self.vgg_loss = networks.PerceptualLoss()
            self.vgg_loss.cuda()
            self.vgg = networks.load_vgg16("./model")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        skip = opt.skip > 0
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, self.gpu_ids,
                                        skip=skip, opt=opt)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # GAN, cycle, and identity criteria are kept for parity with the
            # CycleGAN code path, but backward_G() only optimizes criterionL1.
            if opt.use_wgan:
                self.criterionGAN = networks.DiscLossWGANGP()
            else:
                self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            if opt.use_mse:
                self.criterionCycle = torch.nn.MSELoss()
            else:
                self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        if opt.isTrain:
            self.netG_A.train()
        else:
            self.netG_A.eval()
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        input_img = input['input_img']
        input_A_gray = input['A_gray']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.input_img.resize_(input_img.size()).copy_(input_img)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
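
    # Expected batch format (illustrative sketch; variable names here are
    # hypothetical, shapes follow the buffers allocated in initialize()):
    #   model.set_input({'A': img_A, 'B': img_B, 'input_img': img_A,
    #                    'A_gray': gray, 'A_paths': pA, 'B_paths': pB})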

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_A_gray = Variable(self.input_A_gray)
        self.real_img = Variable(self.input_img)

    def test(self):
        # volatile=True disables graph construction at inference time
        # (pre-0.4 idiom; torch.no_grad() is the modern equivalent).
        self.real_A = Variable(self.input_A, volatile=True)
        self.real_A_gray = Variable(self.input_A_gray, volatile=True)
        self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)
        self.real_B = Variable(self.input_B, volatile=True)

    def predict(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.real_A_gray = Variable(self.input_A_gray, volatile=True)
        self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)

        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        if self.opt.skip == 1:
            latent_real_A = util.tensor2im(self.latent_real_A.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
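
    # Inference sketch (hypothetical driver code, not part of the original
    # file): after set_input(data), predict() returns numpy visuals:
    #   visuals = model.predict()  # OrderedDict: 'real_A', 'fake_B', ...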

    def get_image_paths(self):
        return self.image_paths

    def backward_G(self):
        self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)

        # Paired supervision: plain L1 between the generated and ground-truth
        # images, weighted by opt.l1.
        self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1
        self.loss_G = self.L1_AB
        self.loss_G.backward()

    def optimize_parameters(self, epoch):
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self, epoch):
        # .data[0] reads a scalar loss on PyTorch <= 0.3; use .item() on 0.4+.
        L1 = self.L1_AB.data[0]
        loss_G = self.loss_G.data[0]
        return OrderedDict([('L1', L1), ('loss_G', loss_G)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)

    def update_learning_rate(self):
        # Either halve the current rate once (opt.new_lr) or decay it
        # linearly toward zero over opt.niter_decay epochs.
        if self.opt.new_lr:
            lr = self.old_lr / 2
        else:
            lrd = self.opt.lr / self.opt.niter_decay
            lr = self.old_lr - lrd
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
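
# Minimal training-loop sketch (assumes the repo's TrainOptions-style `opt`
# and a dataloader yielding the dict described above set_input(); everything
# outside this file is an assumption, not a confirmed API):
#   model = PairModel()
#   model.initialize(opt)
#   for epoch in range(1, opt.niter + opt.niter_decay + 1):
#       for data in dataset:
#           model.set_input(data)
#           model.optimize_parameters(epoch)
#       print(model.get_current_errors(epoch))
#       if epoch > opt.niter:
#           model.update_learning_rate()
#   model.save('latest')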