# EnlightenGAN / models / unit_model.py
# HenryGong's picture
# Upload 84 files
# aba0e05 verified
import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.util import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from .unit_network import *
import sys
def get_config(config):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config: Path of the YAML file to read.

    Returns:
        The object the YAML document deserializes to.  The rest of this
        module indexes it like a dict (e.g. ``config['lr']``).
    """
    # Imported lazily so that merely importing this module does not require
    # PyYAML to be installed.
    import yaml

    with open(config, 'r') as stream:
        # safe_load only constructs plain Python objects, so a config file
        # cannot trigger arbitrary object construction.  It also avoids the
        # mandatory ``Loader`` argument that newer PyYAML versions demand
        # from ``yaml.load``.
        return yaml.safe_load(stream)
class UNITModel(BaseModel):
    """Two-domain image translation model made of two VAE generators and,
    in training mode, two multi-scale discriminators.

    ``gen_a`` / ``dis_a`` belong to domain A and ``gen_b`` / ``dis_b`` to
    domain B.  Hyper-parameters are read from the YAML file named by
    ``opt.config``; run-time options (batch size, direction, epochs, ...)
    come from ``opt``.
    """

    def name(self):
        """Return the identifier of this model."""
        return 'UNITModel'

    # ------------------------------------------------------------------
    # small private helpers
    # ------------------------------------------------------------------
    def _vgg_weight(self):
        """Return the perceptual-loss weight, or 0 when the config omits it."""
        return self.config.get('vgg_w', 0)

    @staticmethod
    def _to_scalar(loss):
        """Convert a loss value to a plain Python number.

        Accepts the numeric placeholders used for disabled loss terms,
        tensors exposing ``.item()``, and old-style Variables that only
        support ``.data[0]``.
        """
        if isinstance(loss, (int, float)):
            return float(loss)
        if hasattr(loss, 'item'):
            return loss.item()
        return loss.data[0]

    # ------------------------------------------------------------------
    # setup
    # ------------------------------------------------------------------
    def initialize(self, opt):
        """Build the networks and, when training, optimizers and losses.

        Args:
            opt: Option object.  Reads ``config``, ``batchSize``,
                ``fineSize``, ``input_nc``, ``output_nc``,
                ``continue_train``, ``which_epoch`` and ``pool_size``.
        """
        BaseModel.initialize(self, opt)
        self.config = get_config(opt.config)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # Each generator is sized for its own domain, mirroring how the
        # discriminators below are built.
        self.gen_a = VAEGen(self.config['input_dim_a'], self.config['gen'])
        self.gen_b = VAEGen(self.config['input_dim_b'], self.config['gen'])

        if self.isTrain:
            self.dis_a = MsImageDis(self.config['input_dim_a'], self.config['dis'])  # discriminator for domain a
            self.dis_b = MsImageDis(self.config['input_dim_b'], self.config['dis'])  # discriminator for domain b
            # Initialise the discriminators *before* any checkpoint is
            # loaded so that resuming training keeps the restored weights.
            # Only the discriminators receive an explicit init here.
            self.dis_a.apply(weights_init('gaussian'))
            self.dis_b.apply(weights_init('gaussian'))

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.gen_a, 'G_A', which_epoch)
            self.load_network(self.gen_b, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.dis_a, 'D_A', which_epoch)
                self.load_network(self.dis_b, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = self.config['lr']
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # Setup the optimizers
            beta1 = self.config['beta1']
            beta2 = self.config['beta2']
            dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
            gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
            self.dis_opt = torch.optim.Adam(
                [p for p in dis_params if p.requires_grad],
                lr=self.config['lr'], betas=(beta1, beta2),
                weight_decay=self.config['weight_decay'])
            self.gen_opt = torch.optim.Adam(
                [p for p in gen_params if p.requires_grad],
                lr=self.config['lr'], betas=(beta1, beta2),
                weight_decay=self.config['weight_decay'])
            self.dis_scheduler = get_scheduler(self.dis_opt, self.config)
            self.gen_scheduler = get_scheduler(self.gen_opt, self.config)

            # Load VGG model if needed (only used by the generator loss)
            if self._vgg_weight() > 0:
                self.vgg = load_vgg16(self.config['vgg_model_path'] + '/models')
                self.vgg.eval()
                for param in self.vgg.parameters():
                    param.requires_grad = False
                if not hasattr(self, 'instancenorm'):
                    # compute_vgg_loss() normalises VGG features with this
                    # layer.  NOTE(review): 512 channels assumes the VGG
                    # feature map returned by load_vgg16 has 512 channels
                    # -- confirm against util.load_vgg16.
                    self.instancenorm = torch.nn.InstanceNorm2d(512, affine=False).cuda()

        self.gen_a.cuda()
        self.gen_b.cuda()
        if self.isTrain:
            # The discriminators only exist in training mode.
            self.dis_a.cuda()
            self.dis_b.cuda()

        print('---------- Networks initialized -------------')
        networks.print_network(self.gen_a)
        networks.print_network(self.gen_b)
        if self.isTrain:
            networks.print_network(self.dis_a)
            networks.print_network(self.dis_b)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one data batch into the model's input buffers.

        Args:
            input: Mapping with keys ``'A'``, ``'B'`` and ``'A_paths'`` /
                ``'B_paths'``.  When ``opt.which_direction`` is not
                ``'AtoB'`` the two domains are swapped.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        self.real_A = Variable(self.input_A.cuda())
        self.real_B = Variable(self.input_B.cuda())

    # ------------------------------------------------------------------
    # inference
    # ------------------------------------------------------------------
    def test(self):
        """Run reconstruction, translation and (optionally) cycle passes
        on the current input and store the results on ``self``.

        Sets ``x_a_recon``, ``x_ab``, ``x_aba``, ``x_b_recon``, ``x_ba``
        and ``x_bab``; the two cycle images are ``None`` when
        ``recon_x_cyc_w`` is not positive.
        """
        # NOTE(review): ``volatile`` is the pre-0.4 PyTorch way to disable
        # autograd; kept for compatibility with the rest of this file.
        self.real_A = Variable(self.input_A.cuda(), volatile=True)
        self.real_B = Variable(self.input_B.cuda(), volatile=True)
        x_a, x_b = self.real_A, self.real_B

        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # NOTE(review): the decoder output is added to the input with
        # weight 1 here but with weight 0 in gen_update(); confirm which
        # residual weight is intended.
        x_a_recon = self.gen_a.decode(h_a + n_a) + x_a*1
        x_b_recon = self.gen_b.decode(h_b + n_b) + x_b*1
        x_ba = self.gen_a.decode(h_b + n_b) + x_b*1
        x_ab = self.gen_b.decode(h_a + n_a) + x_a*1
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)

        if self.config['recon_x_cyc_w'] > 0:
            x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + x_ab*1
            x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + x_ba*1
        else:
            x_aba = None
            x_bab = None

        self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba
        self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab

    # get image paths
    def get_image_paths(self):
        """Return the image paths stored by the last ``set_input`` call."""
        return self.image_paths

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------
    def optimize_parameters(self):
        """Run one generator update followed by one discriminator update."""
        self.gen_update(self.real_A, self.real_B)
        self.dis_update(self.real_A, self.real_B)

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between two tensors."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain B and ``x_b`` to domain A.

        Uses only the encoder means (the noise term is discarded).

        Returns:
            Tuple ``(x_ab, x_ba)``.
        """
        # NOTE(review): self.eval()/self.train() must be provided by
        # BaseModel (not visible here) -- confirm it is an nn.Module or
        # defines them.
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        """Return the mean of ``mu ** 2``, used as the latent penalty."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b):
        """Compute all generator losses and take one generator step.

        Args:
            x_a: Batch from domain A.
            x_b: Batch from domain B.
        """
        self.gen_opt.zero_grad()
        cyc_w = self.config['recon_x_cyc_w']
        vgg_w = self._vgg_weight()

        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a) + 0*x_a
        x_b_recon = self.gen_b.decode(h_b + n_b) + 0*x_b
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b) + 0*x_b
        x_ab = self.gen_b.decode(h_a + n_a) + 0*x_a
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        if cyc_w > 0:
            x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + 0*x_ab
            x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + 0*x_ba
        else:
            x_aba = None
            x_bab = None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        # The cycle images do not exist when the cycle weight is disabled,
        # so the corresponding loss terms fall back to 0.
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if cyc_w > 0 else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if cyc_w > 0 else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = self.config['gan_w'] * self.loss_gen_adv_a + \
            self.config['gan_w'] * self.loss_gen_adv_b + \
            self.config['recon_x_w'] * self.loss_gen_recon_x_a + \
            self.config['recon_kl_w'] * self.loss_gen_recon_kl_a + \
            self.config['recon_x_w'] * self.loss_gen_recon_x_b + \
            self.config['recon_kl_w'] * self.loss_gen_recon_kl_b + \
            cyc_w * self.loss_gen_cyc_x_a + \
            self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
            cyc_w * self.loss_gen_cyc_x_b + \
            self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

        self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba
        self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean squared difference between the instance-normalised
        VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def dis_update(self, x_a, x_b):
        """Compute the discriminator losses and take one discriminator step.

        Args:
            x_a: Batch from domain A.
            x_b: Batch from domain B.
        """
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss; fakes are detached so no gradient reaches the generators
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = self.config['gan_w'] * self.loss_dis_a + self.config['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    # ------------------------------------------------------------------
    # reporting
    # ------------------------------------------------------------------
    def get_current_errors(self):
        """Return the latest loss values as an ``OrderedDict`` of numbers.

        The ``vgg_A`` / ``vgg_B`` entries are present only when the
        perceptual-loss weight is positive.
        """
        scalar = self._to_scalar
        D_A = scalar(self.loss_dis_a)
        G_A = scalar(self.loss_gen_adv_a)
        kl_A = scalar(self.loss_gen_recon_kl_a)
        Cyc_A = scalar(self.loss_gen_cyc_x_a)
        D_B = scalar(self.loss_dis_b)
        G_B = scalar(self.loss_gen_adv_b)
        kl_B = scalar(self.loss_gen_recon_kl_b)
        Cyc_B = scalar(self.loss_gen_cyc_x_b)
        if self._vgg_weight() > 0:
            vgg_A = scalar(self.loss_gen_vgg_a)
            vgg_B = scalar(self.loss_gen_vgg_b)
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('kl_A', kl_A), ('vgg_A', vgg_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('kl_B', kl_B), ('vgg_B', vgg_B)])
        return OrderedDict([('D_A', D_A), ('G_A', G_A), ('kl_A', kl_A), ('Cyc_A', Cyc_A),
                            ('D_B', D_B), ('G_B', G_B), ('kl_B', kl_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return the latest images as an ``OrderedDict`` of arrays produced
        by ``util.tensor2im``.

        ``ABA`` and ``BAB`` are omitted when the cycle images were not
        computed (``recon_x_cyc_w`` not positive).
        """
        visuals = [('real_A', util.tensor2im(self.real_A.data)),
                   ('A_B', util.tensor2im(self.x_ab.data)),
                   ('recon_A', util.tensor2im(self.x_a_recon.data))]
        if self.x_aba is not None:
            visuals.append(('ABA', util.tensor2im(self.x_aba.data)))
        visuals.extend([('real_B', util.tensor2im(self.real_B.data)),
                        ('B_A', util.tensor2im(self.x_ba.data)),
                        ('recon_B', util.tensor2im(self.x_b_recon.data))])
        if self.x_bab is not None:
            # Show the B->A->B cycle image, not the plain reconstruction.
            visuals.append(('BAB', util.tensor2im(self.x_bab.data)))
        return OrderedDict(visuals)

    def save(self, label):
        """Save all four networks under the given checkpoint label."""
        self.save_network(self.gen_a, 'G_A', label, self.gpu_ids)
        self.save_network(self.dis_a, 'D_A', label, self.gpu_ids)
        self.save_network(self.gen_b, 'G_B', label, self.gpu_ids)
        self.save_network(self.dis_b, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by one linear step.

        The step size is ``config['lr'] / opt.niter_decay``.
        """
        lrd = self.config['lr'] / self.opt.niter_decay
        lr = self.old_lr - lrd
        # The learning rate lives in the optimizers' parameter groups,
        # not on the network modules.
        for optimizer in (self.gen_opt, self.dis_opt):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr