|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from basicsr.utils.registry import ARCH_REGISTRY |
|
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer |
|
|
|
|
|
class SPADEGenerator(BaseNetwork):
    """Generator with SPADEResBlock.

    The input is encoded into a low-resolution feature map (see ``encode``)
    and then decoded by a stack of ``SPADEResnetBlock`` s. Every block is
    conditioned on a guidance image, which is the input image itself in
    ``forward``.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        use_vae (bool): If True, ``self.fc`` is an ``nn.Linear`` taking
            ``z_dim`` features. Otherwise it is a 3x3 conv taking
            ``num_in_ch`` channels. Default: False.
        z_dim (int): Input size of ``self.fc`` when ``use_vae`` is True.
            Default: 256.
        crop_size (int): Used to set ``self.sw``/``self.sh`` to
            ``crop_size // 2**self.scale_ratio``. Default: 512.
        norm_g (str): Normalization spec forwarded to every
            ``SPADEResnetBlock``. Default: 'spectralspadesyncbatch3x3'.
        is_train (bool): If True, only the first ``self.train_phase + 1``
            upsampling blocks are used; otherwise all of them. Default: True.
        init_train_phase (int): Initial value of ``self.train_phase``.
            Default: 3.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):
        super().__init__()
        self.nf = num_feat
        self.input_nc = num_in_ch
        self.is_train = is_train
        self.train_phase = init_train_phase

        # The default encoder downsamples by 2**scale_ratio (= 32) per side.
        self.scale_ratio = 5
        self.sw = crop_size // (2**self.scale_ratio)
        self.sh = self.sw

        if use_vae:
            # Maps a z_dim vector to a flattened 16*nf x sh x sw feature map.
            self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
        else:
            # Deterministic path: convolve the downsampled input image instead
            # of sampling a latent vector.
            self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
        self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        # Progressive upsampling blocks; ups[i] halves the channel count.
        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
            SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
            SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
            SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
        ])

        # One 3-channel output head per upsampling stage; to_rgbs[i] matches
        # the output channels of ups[i].
        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * self.nf, 3, 3, padding=1),
            nn.Conv2d(4 * self.nf, 3, 3, padding=1),
            nn.Conv2d(2 * self.nf, 3, 3, padding=1),
            nn.Conv2d(1 * self.nf, 3, 3, padding=1)
        ])

        self.up = nn.Upsample(scale_factor=2)

    def encode(self, input_tensor):
        """Encode ``input_tensor`` into the initial feature map.

        Can be overridden in derived classes. Default: nearest-neighbour
        downsampling of each spatial side by ``2**self.scale_ratio`` (= 32),
        followed by ``self.fc``.

        Raises:
            NotImplementedError: If the generator was built with
                ``use_vae=True``. ``self.fc`` is then an ``nn.Linear`` over
                ``z_dim`` features and cannot consume an image tensor.
        """
        if isinstance(self.fc, nn.Linear):
            raise NotImplementedError('SPADEGenerator.encode does not support use_vae=True; '
                                      'override encode() in a subclass instead.')
        h, w = input_tensor.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input_tensor, size=(sh, sw))
        return self.fc(x)

    def _num_phases(self):
        """Return how many blocks of ``self.ups`` to run."""
        if self.is_train:
            return self.train_phase + 1
        return len(self.to_rgbs)

    def _decode(self, guides, phase):
        """Run encoder + SPADE trunk, with one guidance image per level.

        ``guides[0]`` is fed to ``encode``. ``guides[1]``, ``guides[2]`` and
        ``guides[3]`` condition ``head_0``, ``g_middle_0`` and ``g_middle_1``.
        ``guides[4 + i]`` conditions ``self.ups[i]``. ``guides`` therefore
        needs at least ``4 + phase`` entries.
        """
        x = self.encode(guides[0])
        x = self.head_0(x, guides[1])

        x = self.up(x)
        x = self.g_middle_0(x, guides[2])
        x = self.g_middle_1(x, guides[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guides[4 + i])

        # Use the output head that matches the last block that was run.
        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        return torch.tanh(x)

    def forward(self, x):
        """Generate an output guided by ``x`` itself at every level.

        In the original SPADE the conditioning input is a segmentation map;
        here the input image is used instead.

        Returns:
            torch.Tensor: tanh-activated output of ``self.to_rgbs[phase - 1]``.
        """
        phase = self._num_phases()
        return self._decode([x] * (4 + phase), phase)

    def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
        """Forward pass with different guidance images at different levels.

        A helper method for subspace visualization. ``input_x`` (A) and ``seg``
        (B) are different images. The ``4 + phase`` levels are, in order:
        the encoder input, ``head_0``, ``g_middle_0``, ``g_middle_1`` and the
        active ``ups`` blocks. Each level is given A or B:

        * ``mode='progressive'``: the first ``n`` levels use A and the rest
          use B, e.g. AAABBB. ``n`` is clamped to ``[0, 4 + phase]``.
        * ``mode='one_plug'``: only level ``n`` uses A, e.g. BBBABB. ``n`` is
          clamped to ``[0, 3 + phase]``.
        * ``mode='one_ablate'``: only level ``n`` uses B, e.g. AAABAA. If ``n``
          is outside ``[0, 3 + phase]``, nothing is ablated and the result is
          ``self.forward(input_x)``.

        Args:
            input_x (Tensor): Guidance image A.
            seg (Tensor | None): Guidance image B. If None, this is simply
                ``self.forward(input_x)``.
            n (int): Level index, interpreted according to ``mode``.
            mode (str): One of 'progressive', 'one_plug', 'one_ablate'.

        Returns:
            torch.Tensor: tanh-activated output of ``self.to_rgbs[phase - 1]``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if seg is None:
            return self.forward(input_x)

        phase = self._num_phases()
        num_levels = 4 + phase

        if mode == 'progressive':
            n = max(min(n, num_levels), 0)
            guide_list = [input_x] * n + [seg] * (num_levels - n)
        elif mode == 'one_plug':
            n = max(min(n, num_levels - 1), 0)
            guide_list = [seg] * num_levels
            guide_list[n] = input_x
        elif mode == 'one_ablate':
            # Out-of-range n (on either side) ablates nothing. A negative n
            # would otherwise silently wrap around and index from the end.
            if n < 0 or n > num_levels - 1:
                return self.forward(input_x)
            guide_list = [input_x] * num_levels
            guide_list[n] = seg
        else:
            raise ValueError(f"Unsupported mode '{mode}'; expected 'progressive', 'one_plug' or 'one_ablate'.")

        return self._decode(guide_list, phase)
|
|
|
|
|
@ARCH_REGISTRY.register()
class HiFaceGAN(SPADEGenerator):
    """SPADEGenerator whose encoder is a learnable ``LIPEncoder``.

    Accepts exactly the constructor arguments of ``SPADEGenerator``. The only
    behavioural difference is that ``encode`` delegates to
    ``self.lip_encoder`` instead of the parent's fixed nearest-neighbour
    downsampling.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):
        super().__init__(
            num_in_ch=num_in_ch,
            num_feat=num_feat,
            use_vae=use_vae,
            z_dim=z_dim,
            crop_size=crop_size,
            norm_g=norm_g,
            is_train=is_train,
            init_train_phase=init_train_phase,
        )
        # self.sw / self.sh / self.scale_ratio are set by the parent constructor.
        self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)

    def encode(self, input_tensor):
        """Encode ``input_tensor`` with the learnable LIP encoder."""
        encoder = self.lip_encoder
        return encoder(input_tensor)
|
|
|
|
|
@ARCH_REGISTRY.register()
class HiFaceGANDiscriminator(BaseNetwork):
    """
    Inspired by pix2pixHD multiscale discriminator.

    ``num_d`` PatchGAN sub-discriminators are applied to successively
    average-pooled (2x downsampled) versions of the input.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        conditional_d (bool): Whether to use a conditional discriminator. If
            True, each sub-discriminator expects ``num_in_ch + num_out_ch``
            input channels. Default: True.
        num_d (int): Number of multiscale discriminators. Default: 2.
        n_layers_d (int): Number of conv stages before the final 1-channel
            output conv in each D. Default: 4.
        num_feat (int): Channel number of base intermediate features.
            Default: 64.
        norm_d (str): String to determine normalization layers in D.
            Choices: [spectral][instance/batch/syncbatch]
            Default: 'spectralinstance'.
        keep_features (bool): Keep intermediate features for matching loss, etc.
            Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 conditional_d=True,
                 num_d=2,
                 n_layers_d=4,
                 num_feat=64,
                 norm_d='spectralinstance',
                 keep_features=True):
        super().__init__()
        self.num_d = num_d

        input_nc = num_in_ch
        if conditional_d:
            # NOTE(review): presumably the caller concatenates the condition and
            # the image along the channel axis. Confirm against the model code.
            input_nc += num_out_ch

        for i in range(num_d):
            subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
            self.add_module(f'discriminator_{i}', subnet_d)

    def downsample(self, x):
        """Halve the spatial resolution with 3x3 average pooling.

        Uses stride 2 and padding 1. Padded zeros are excluded from the average.
        """
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, x):
        """Apply every sub-discriminator at its own scale.

        Returns:
            list: One entry per sub-discriminator, in registration order. Entry
            ``i`` is computed on ``x`` downsampled ``i`` times. Each entry is
            whatever the sub-discriminator returns (a list of stage outputs
            when ``keep_features`` is True, otherwise a single tensor).
        """
        subnets = list(self.children())
        result = []
        for idx, net_d in enumerate(subnets):
            result.append(net_d(x))
            # Only prepare the next (coarser) input if another D will consume it.
            if idx < len(subnets) - 1:
                x = self.downsample(x)
        return result
|
|
|
|
|
class NLayerDiscriminator(BaseNetwork):
    """Defines the PatchGAN discriminator with the specified arguments.

    The stages are registered in order as ``model0``, ``model1``, ... All
    convs are 4x4 (assumes ``n_layers_d >= 1``):

    * ``model0``: conv with stride 2 and no normalization, then LeakyReLU(0.2).
    * ``model1`` .. ``model{n_layers_d - 1}``: normalized conv that doubles
      the channel count (capped at 512), then LeakyReLU(0.2). Stride is 2,
      except stride 1 for the last of these stages.
    * final stage: stride-1 conv down to a single output channel.

    Args:
        input_nc (int): Number of input channels.
        n_layers_d (int): Number of conv stages before the output conv.
        num_feat (int): Output channels of the first conv.
        norm_d (str): Spec passed to ``get_nonspade_norm_layer``. The returned
            callable wraps every intermediate conv.
        keep_features (bool): If True, ``forward`` returns all stage outputs;
            otherwise only the final one.
    """

    def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
        super().__init__()
        self.keep_features = keep_features

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))
        wrap_norm = get_nonspade_norm_layer(norm_d)

        stages = [[nn.Conv2d(input_nc, num_feat, kernel_size=kernel, stride=2, padding=pad), nn.LeakyReLU(0.2, False)]]

        channels = num_feat
        last_hidden = n_layers_d - 1
        for layer_idx in range(1, n_layers_d):
            in_channels, channels = channels, min(channels * 2, 512)
            conv_stride = 1 if layer_idx == last_hidden else 2
            conv = nn.Conv2d(in_channels, channels, kernel_size=kernel, stride=conv_stride, padding=pad)
            stages.append([wrap_norm(conv), nn.LeakyReLU(0.2, False)])

        stages.append([nn.Conv2d(channels, 1, kernel_size=kernel, stride=1, padding=pad)])

        # Names model0, model1, ... keep state_dict keys stable.
        for idx, layers in enumerate(stages):
            self.add_module(f'model{idx}', nn.Sequential(*layers))

    def forward(self, x):
        """Run the stages in registration order.

        Returns:
            list | Tensor: Every stage output (input excluded) if
            ``self.keep_features`` is True, otherwise the final stage output.
        """
        outputs = []
        feat = x
        for stage in self.children():
            feat = stage(feat)
            outputs.append(feat)
        return outputs if self.keep_features else feat
|
|