import numpy as np import torch import os from collections import OrderedDict from torch.autograd import Variable import util.util as util from util.image_pool import ImagePool from .base_model import BaseModel from . import networks class Pix2PixModel(BaseModel): def name(self): return 'Pix2PixModel' def initialize(self, opt): BaseModel.initialize(self, opt) self.isTrain = opt.isTrain # define tensors self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) # load/define networks self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids) if not self.isTrain or opt.continue_train: self.load_network(self.netG, 'G', opt.which_epoch) if self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch) if self.isTrain: self.fake_AB_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionL1 = torch.nn.L1Loss() # initialize optimizers self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) print('---------- Networks initialized -------------') networks.print_network(self.netG) if self.isTrain: networks.print_network(self.netD) print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): self.real_A = Variable(self.input_A) self.fake_B = self.netG.forward(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) self.fake_B = self.netG.forward(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths def get_image_paths(self): return self.image_paths def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) self.pred_fake = self.netD.forward(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(self.pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) self.pred_real = self.netD.forward(real_AB) self.loss_D_real = self.criterionGAN(self.pred_real, True) # Combined loss self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) pred_fake = self.netD.forward(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A self.loss_G = self.loss_G_GAN + self.loss_G_L1 self.loss_G.backward() def optimize_parameters(self): self.forward() self.optimizer_D.zero_grad() self.backward_D() self.optimizer_D.step() self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step() def get_current_errors(self): return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]), ('G_L1', self.loss_G_L1.data[0]), ('D_real', self.loss_D_real.data[0]), ('D_fake', self.loss_D_fake.data[0]) ]) def get_current_visuals(self): real_A = util.tensor2im(self.real_A.data) fake_B = util.tensor2im(self.fake_B.data) real_B = util.tensor2im(self.real_B.data) return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) def save(self, label): self.save_network(self.netG, 'G', label, self.gpu_ids) self.save_network(self.netD, 'D', label, self.gpu_ids) def update_learning_rate(self): lrd = self.opt.lr / self.opt.niter_decay lr = self.old_lr - lrd for param_group in self.optimizer_D.param_groups: param_group['lr'] = lr for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr