import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.util import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from .unit_network import *
import sys

def get_config(config):
    import yaml
    with open(config, 'r') as stream:
        # safe_load suffices here: the config is a plain YAML mapping.
        return yaml.safe_load(stream)
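
# Illustrative sketch of the YAML keys this model reads (values are
# placeholders, not the authors' settings; get_scheduler may read further
# keys of its own):
#
#   input_dim_a: 3        # channels of domain-A images
#   input_dim_b: 3        # channels of domain-B images
#   gen: { ... }          # VAEGen architecture options
#   dis: { ... }          # MsImageDis architecture options
#   lr: 0.0001
#   beta1: 0.5
#   beta2: 0.999
#   weight_decay: 0.0001
#   gan_w: 1              # adversarial loss weight
#   recon_x_w: 10         # image reconstruction weight
#   recon_kl_w: 0.01      # KL weight on latent codes
#   recon_x_cyc_w: 10     # cycle reconstruction weight (0 disables the cycle)
#   recon_kl_cyc_w: 0.01  # KL weight on re-encoded (cycle) codes
#   vgg_w: 0              # perceptual loss weight (0 disables the VGG branch)
#   vgg_model_path: .     # '/models' is appended when vgg_w > 0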


class UNITModel(BaseModel):
    def name(self):
        return 'UNITModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        self.config = get_config(opt.config)
        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # One VAE-style generator (encoder + decoder) per domain.
        self.gen_a = VAEGen(self.config['input_dim_a'], self.config['gen'])
        self.gen_b = VAEGen(self.config['input_dim_b'], self.config['gen'])

        if self.isTrain:
            self.dis_a = MsImageDis(self.config['input_dim_a'], self.config['dis'])
            self.dis_b = MsImageDis(self.config['input_dim_b'], self.config['dis'])
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.gen_a, 'G_A', which_epoch)
            self.load_network(self.gen_b, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.dis_a, 'D_A', which_epoch)
                self.load_network(self.dis_b, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = self.config['lr']
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            beta1 = self.config['beta1']
            beta2 = self.config['beta2']
            dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
            gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
            self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                            lr=self.config['lr'], betas=(beta1, beta2),
                                            weight_decay=self.config['weight_decay'])
            self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                            lr=self.config['lr'], betas=(beta1, beta2),
                                            weight_decay=self.config['weight_decay'])
            self.dis_scheduler = get_scheduler(self.dis_opt, self.config)
            self.gen_scheduler = get_scheduler(self.gen_opt, self.config)

            self.dis_a.apply(weights_init('gaussian'))
            self.dis_b.apply(weights_init('gaussian'))

            if 'vgg_w' in self.config and self.config['vgg_w'] > 0:
                # Frozen VGG16 for the perceptual loss; instancenorm is used
                # by compute_vgg_loss to whiten the VGG features.
                self.instancenorm = torch.nn.InstanceNorm2d(512, affine=False)
                self.vgg = load_vgg16(self.config['vgg_model_path'] + '/models')
                self.vgg.eval()
                for param in self.vgg.parameters():
                    param.requires_grad = False

        self.gen_a.cuda()
        self.gen_b.cuda()
        if self.isTrain:
            self.dis_a.cuda()
            self.dis_b.cuda()

        print('---------- Networks initialized -------------')
        networks.print_network(self.gen_a)
        networks.print_network(self.gen_b)
        if self.isTrain:
            networks.print_network(self.dis_a)
            networks.print_network(self.dis_b)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        self.real_A = Variable(self.input_A.cuda())
        self.real_B = Variable(self.input_B.cuda())

    def test(self):
        self.real_A = Variable(self.input_A.cuda(), volatile=True)
        self.real_B = Variable(self.input_B.cuda(), volatile=True)
        x_a, x_b = self.real_A, self.real_B
        # Encode each image to a latent code plus the VAE noise sample.
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # Within-domain reconstructions.
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # Cross-domain translations.
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # Cycle: re-encode the translations and decode back to the source domain.
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if self.config['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if self.config['recon_x_cyc_w'] > 0 else None
        self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba
        self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab

    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        self.gen_update(self.real_A, self.real_B)
        self.dis_update(self.real_A, self.real_B)

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        # Deterministic translation for inference: drop the VAE noise and
        # decode each latent code with the other domain's generator.
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba
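
    # Hypothetical usage (tensor names are illustrative, not from this repo):
    #   fake_b, fake_a = model.forward(Variable(imgs_a.cuda()),
    #                                  Variable(imgs_b.cuda()))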

    def __compute_kl(self, mu):
        # KL(N(mu, I) || N(0, I)) with the encoder variance fixed at 1: only
        # the squared latent code remains.
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss
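
    # Worked check (illustrative): for a d-dimensional code with every entry
    # equal to m, the exact KL to N(0, I) is d * m^2 / 2, while the penalty
    # above is mean(mu^2) = m^2, i.e. the same quantity up to a constant
    # factor that the recon_kl_* weights absorb.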

    def gen_update(self, x_a, x_b):
        self.gen_opt.zero_grad()

        # Encode both domains to latent codes plus VAE noise samples.
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)

        # Within-domain reconstruction (the 0*x terms add zero and leave the
        # outputs unchanged).
        x_a_recon = self.gen_a.decode(h_a + n_a) + 0*x_a
        x_b_recon = self.gen_b.decode(h_b + n_b) + 0*x_b

        # Cross-domain translation.
        x_ba = self.gen_a.decode(h_b + n_b) + 0*x_b
        x_ab = self.gen_b.decode(h_a + n_a) + 0*x_a

        # Re-encode the translations for the cycle terms.
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)

        x_aba = self.gen_a.decode(h_a_recon + n_a_recon) + 0*x_ab if self.config['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(h_b_recon + n_b_recon) + 0*x_ba if self.config['recon_x_cyc_w'] > 0 else None

        # Reconstruction, KL, and cycle losses (the cycle terms are skipped
        # when their weight is zero, since x_aba / x_bab are None then).
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if x_aba is not None else 0
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if x_bab is not None else 0
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # Adversarial losses: each translation must fool the discriminator of
        # its target domain.
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # Optional perceptual losses.
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if self.config['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if self.config['vgg_w'] > 0 else 0

        # Weighted total generator objective.
        self.loss_gen_total = self.config['gan_w'] * self.loss_gen_adv_a + \
                              self.config['gan_w'] * self.loss_gen_adv_b + \
                              self.config['recon_x_w'] * self.loss_gen_recon_x_a + \
                              self.config['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              self.config['recon_x_w'] * self.loss_gen_recon_x_b + \
                              self.config['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              self.config['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              self.config['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              self.config['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              self.config['vgg_w'] * self.loss_gen_vgg_a + \
                              self.config['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

        # Cache the intermediate images for visualization.
        self.x_a_recon, self.x_ab, self.x_aba = x_a_recon, x_ab, x_aba
        self.x_b_recon, self.x_ba, self.x_bab = x_b_recon, x_ba, x_bab

    def compute_vgg_loss(self, vgg, img, target):
        # Perceptual loss: MSE between instance-normalized VGG features.
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def dis_update(self, x_a, x_b):
        self.dis_opt.zero_grad()

        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)

        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)

        # detach() keeps discriminator gradients out of the generators.
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = self.config['gan_w'] * self.loss_dis_a + self.config['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def get_current_errors(self):
        D_A = self.loss_dis_a.data[0]
        G_A = self.loss_gen_adv_a.data[0]
        kl_A = self.loss_gen_recon_kl_a.data[0]
        Cyc_A = self.loss_gen_cyc_x_a.data[0]
        D_B = self.loss_dis_b.data[0]
        G_B = self.loss_gen_adv_b.data[0]
        kl_B = self.loss_gen_recon_kl_b.data[0]
        Cyc_B = self.loss_gen_cyc_x_b.data[0]
        if self.config['vgg_w'] > 0:
            vgg_A = self.loss_gen_vgg_a.data[0]
            vgg_B = self.loss_gen_vgg_b.data[0]
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('kl_A', kl_A), ('vgg_A', vgg_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('kl_B', kl_B), ('vgg_B', vgg_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('kl_A', kl_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('kl_B', kl_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        recon_A = util.tensor2im(self.x_a_recon.data)
        A_B = util.tensor2im(self.x_ab.data)
        ABA = util.tensor2im(self.x_aba.data)
        real_B = util.tensor2im(self.real_B.data)
        recon_B = util.tensor2im(self.x_b_recon.data)
        B_A = util.tensor2im(self.x_ba.data)
        BAB = util.tensor2im(self.x_bab.data)
        return OrderedDict([('real_A', real_A), ('A_B', A_B), ('recon_A', recon_A), ('ABA', ABA),
                            ('real_B', real_B), ('B_A', B_A), ('recon_B', recon_B), ('BAB', BAB)])

    def save(self, label):
        self.save_network(self.gen_a, 'G_A', label, self.gpu_ids)
        self.save_network(self.dis_a, 'D_A', label, self.gpu_ids)
        self.save_network(self.gen_b, 'G_B', label, self.gpu_ids)
        self.save_network(self.dis_b, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        # Linear decay of the shared learning rate. param_groups live on the
        # optimizers, not on the networks.
        lrd = self.config['lr'] / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.gen_opt.param_groups:
            param_group['lr'] = lr
        for param_group in self.dis_opt.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
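
    # Note: initialize() also builds self.gen_scheduler / self.dis_scheduler
    # via get_scheduler; stepping those once per epoch would be a config-driven
    # alternative to this manual linear decay.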