import numpy as np
import math
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

from .transformer import Transformer
from . import BigGAN_layers as layers
from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
from util.util import to_device, load_network
from .networks import init_weights
from params import *

from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock


class Decoder(nn.Module):
    def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
        super(Decoder, self).__init__()

        self.model = []
        self.model += [ResBlocks(n_res, dim, res_norm,
                                 activ, pad_type=pad_type)]
        for i in range(ups):
            self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2,
                                       norm='in',
                                       activation=activ,
                                       pad_type=pad_type)]
            dim //= 2
        self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
                                   norm='none',
                                   activation='tanh',
                                   pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        y = self.model(x)
        return y
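

# A minimal shape sketch (assumption: ResBlocks and Conv2dBlock preserve
# spatial size, so only the nn.Upsample stages change resolution). With the
# defaults ups=3, dim=512, out_dim=1, each upsample doubles H and W while the
# following conv halves the channel count, and the final tanh maps outputs
# to [-1, 1]:
#
#     dec = Decoder(res_norm='in')          # as instantiated by Generator below
#     y = dec(torch.randn(2, 512, 4, 32))   # -> torch.Size([2, 1, 32, 256])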


def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    arch = {}
    arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
                 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
                 'resolution': [8, 16, 32, 64, 128, 256, 512],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 10)}}
    arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
                 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
                 'resolution': [8, 16, 32, 64, 128, 256],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 9)}}
    arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
                 'resolution': [8, 16, 32, 64, 128],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 8)}}
    arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
                'out_channels': [ch * item for item in [16, 8, 4, 2]],
                'upsample': [(2, 2), (2, 2), (2, 2), (2, 2)],
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)}}
    arch[63] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
                'out_channels': [ch * item for item in [16, 8, 4, 2]],
                'upsample': [(2, 2), (2, 2), (2, 2), (2, 1)],
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)},
                'kernel1': [3, 3, 3, 3],
                'kernel2': [3, 3, 1, 1]}
    # A duplicate arch[32] entry without the 'kernel1'/'kernel2' keys was
    # immediately overwritten by this one, so only this one is kept.
    arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
                'out_channels': [ch * item for item in [4, 4, 4]],
                'upsample': [(2, 2), (2, 2), (2, 2)],
                'resolution': [8, 16, 32],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)},
                'kernel1': [3, 3, 3],
                'kernel2': [3, 3, 1]}
    arch[129] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
                 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (1, 2), (1, 2)],
                 'resolution': [8, 16, 32, 64, 128, 256, 512],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 10)}}
    arch[33] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
                'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                'upsample': [(2, 2), (2, 2), (2, 2), (1, 2), (1, 2)],
                'resolution': [8, 16, 32, 64, 128],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 8)}}
    arch[31] = {'in_channels': [ch * item for item in [16, 16, 4, 2]],
                'out_channels': [ch * item for item in [16, 4, 2, 1]],
                'upsample': [(2, 2), (2, 2), (2, 2), (1, 2)],
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)},
                'kernel1': [3, 3, 3, 3],
                'kernel2': [3, 1, 1, 1]}
    arch[16] = {'in_channels': [ch * item for item in [8, 4, 2]],
                'out_channels': [ch * item for item in [4, 2, 1]],
                'upsample': [(2, 2), (2, 2), (2, 1)],
                'resolution': [8, 16, 16],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)},
                'kernel1': [3, 3, 3],
                'kernel2': [3, 3, 1]}
    arch[17] = {'in_channels': [ch * item for item in [8, 4, 2]],
                'out_channels': [ch * item for item in [4, 2, 1]],
                'upsample': [(2, 2), (2, 2), (2, 1)],
                'resolution': [8, 16, 16],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)},
                'kernel1': [3, 3, 3],
                'kernel2': [3, 3, 1]}
    arch[20] = {'in_channels': [ch * item for item in [8, 4, 2]],
                'out_channels': [ch * item for item in [4, 2, 1]],
                'upsample': [(2, 2), (2, 2), (2, 1)],
                'resolution': [8, 16, 16],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)},
                'kernel1': [3, 3, 3],
                'kernel2': [3, 1, 1]}
    return arch
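

# Reading one entry (sketch): for a 32-pixel generator with ch=64,
#
#     a = G_arch(ch=64, attention='64')[32]
#     a['in_channels']    # [256, 256, 256]
#     a['out_channels']   # [256, 256, 256]
#     a['upsample']       # [(2, 2), (2, 2), (2, 2)]: a 4x4 bottom grows to 32x32
#     a['attention'][32]  # False here, since 32 is not in the '64' attention list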


class Generator(nn.Module):
    def __init__(self, G_ch=64, dim_z=128, bottom_width=4, bottom_height=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, no_hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 BN_eps=1e-5, SN_eps=1e-12, G_fp16=False,
                 G_init='ortho', skip_init=False,
                 G_param='SN', norm_style='bn', gpu_ids=[], bn_linear='embed', input_nc=3,
                 one_hot=False, first_layer=False, one_hot_k=1,
                 **kwargs):
        super(Generator, self).__init__()
        self.name = 'G'
        self.first_layer = first_layer
        self.gpu_ids = gpu_ids
        self.one_hot = one_hot
        self.one_hot_k = one_hot_k
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # Spatial size of the bottom (first) feature map
        self.bottom_width = bottom_width
        self.bottom_height = bottom_height
        # Output resolution; selects an entry of G_arch
        self.resolution = resolution
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        # Use a shared class embedding?
        self.G_shared = G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space: feed a chunk of z to every block
        self.hier = not no_hier
        self.cross_replica = cross_replica
        self.mybn = mybn
        self.activation = G_activation
        self.init = G_init
        self.G_param = G_param
        self.norm_style = norm_style
        self.BN_eps = BN_eps
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        self.arch = G_arch(self.ch, self.attention)[resolution]
        self.bn_linear = bn_linear

        # Project transformer features (512-d) to the decoder's input width
        self.linear_q = nn.Linear(512, 2048 * 2)

        # NOTE: build() is expected to construct the DETR-style transformer
        # encoder-decoder; it is not defined or imported in this file.
        self.DETR = build()
        self.DEC = Decoder(res_norm='in')

        if self.hier:
            # One chunk of z for the first layer, plus one per generator block
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality so it divides evenly into slots
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs and linears to use: spectrally normalized or plain
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        if one_hot:
            self.which_embedding = functools.partial(layers.SNLinear,
                                                     num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                     eps=self.SN_eps)
        else:
            self.which_embedding = nn.Embedding

        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        if self.bn_linear == 'SN':
            bn_linear = functools.partial(self.which_linear, bias=False)
        if self.G_shared:
            input_size = self.shared_dim + self.z_chunk_size
            # NOTE: the original code left self.which_bn undefined on this path;
            # class-conditional BN, as in the hier branch below, is assumed here.
            self.which_bn = functools.partial(layers.ccbn,
                                              which_linear=bn_linear,
                                              cross_replica=self.cross_replica,
                                              mybn=self.mybn,
                                              input_size=input_size,
                                              norm_style=self.norm_style,
                                              eps=self.BN_eps)
        elif self.hier:
            if self.first_layer:
                input_size = self.z_chunk_size
            else:
                input_size = self.n_classes + self.z_chunk_size
            self.which_bn = functools.partial(layers.ccbn,
                                              which_linear=bn_linear,
                                              cross_replica=self.cross_replica,
                                              mybn=self.mybn,
                                              input_size=input_size,
                                              norm_style=self.norm_style,
                                              eps=self.BN_eps)
        else:
            input_size = self.n_classes
            self.which_bn = functools.partial(layers.bn,
                                              cross_replica=self.cross_replica,
                                              mybn=self.mybn,
                                              eps=self.BN_eps)

        # Prepare the shared class embedding, if used
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())

        # First (fully connected) layer: maps a z chunk, optionally concatenated
        # with class information, to the bottom feature map
        if self.first_layer:
            if self.one_hot:
                self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes,
                                                self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
            else:
                self.linear = self.which_linear(self.dim_z // self.num_slots + 1,
                                                self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
            if self.one_hot_k == 1:
                self.linear = self.which_linear((self.dim_z // self.num_slots) * self.n_classes,
                                                self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
            if self.one_hot_k > 1:
                self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes * self.one_hot_k,
                                                self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
        else:
            self.linear = self.which_linear(self.dim_z // self.num_slots,
                                            self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))

        # Residual generator blocks; arch entries with 'kernel1'/'kernel2' keys
        # get per-block kernel sizes, everything else uses the default 3x3 convs
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            if 'kernel1' in self.arch.keys():
                padd1 = 1 if self.arch['kernel1'][index] > 1 else 0
                padd2 = 1 if self.arch['kernel2'][index] > 1 else 0
                conv1 = functools.partial(layers.SNConv2d,
                                          kernel_size=self.arch['kernel1'][index], padding=padd1,
                                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                          eps=self.SN_eps)
                conv2 = functools.partial(layers.SNConv2d,
                                          kernel_size=self.arch['kernel2'][index], padding=padd2,
                                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                          eps=self.SN_eps)
                self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                               out_channels=self.arch['out_channels'][index],
                                               which_conv1=conv1,
                                               which_conv2=conv2,
                                               which_bn=self.which_bn,
                                               activation=self.activation,
                                               upsample=(functools.partial(F.interpolate,
                                                                           scale_factor=self.arch['upsample'][index])
                                                         if index < len(self.arch['upsample']) else None))]]
            else:
                self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                               out_channels=self.arch['out_channels'][index],
                                               which_conv1=self.which_conv,
                                               which_conv2=self.which_conv,
                                               which_bn=self.which_bn,
                                               activation=self.activation,
                                               upsample=(functools.partial(F.interpolate,
                                                                           scale_factor=self.arch['upsample'][index])
                                                         if index < len(self.arch['upsample']) else None))]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it is iterable
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output layer: batchnorm -> activation -> conv
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], input_nc))

        # Initialize weights unless explicitly skipped
        if not skip_init:
            self = init_weights(self, G_init)

    def forward(self, x, y_ind, y):
        # Encode the style images x, queried with the character indices y_ind,
        # through the transformer encoder-decoder
        h_all = self.DETR(x, y_ind)

        # Project each 512-d query output to 4096-d
        h = self.linear_q(h_all)
        h = h.contiguous()

        if self.first_layer:
            # Fold the (B, N, 4096) sequence into a (B, 512, 4, 2N) feature map
            h = h.view(h.size(0), h.shape[1] * 2, 4, -1)
            h = h.permute(0, 3, 2, 1)
        else:
            h = h.view(h.size(0), -1, self.bottom_width, self.bottom_height)

        # Decode the feature map into the output image
        h = self.DEC(h)
        return h
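
# Shape sketch for Generator.forward (assumption: self.DETR returns one 512-d
# vector per query character, i.e. h_all is (B, N, 512)):
#
#     h_all: (B, N, 512) --linear_q--> (B, N, 4096)
#     view + permute:                  (B, 512, 4, 2N)
#     self.DEC (ups=3):                (B, 1, 32, 16N), values in [-1, 1]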


def D_arch(ch=64, attention='64', input_nc=3, ksize='333333', dilation='111111'):
    arch = {}
    arch[256] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[128] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                 'downsample': [True] * 5 + [False],
                 'resolution': [64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[64] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 7)}}
    arch[63] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 7)}}
    arch[32] = {'in_channels': [input_nc] + [item * ch for item in [4, 4, 4]],
                'out_channels': [item * ch for item in [4, 4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 16, 16, 16],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 6)}}
    arch[129] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[33] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                'downsample': [True] * 5 + [False],
                'resolution': [64, 32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 10)}}
    arch[31] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                'downsample': [True] * 5 + [False],
                'resolution': [64, 32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 10)}}
    arch[16] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
                'out_channels': [item * ch for item in [1, 8, 16, 16]],
                'downsample': [True] * 3 + [False],
                'resolution': [16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 5)}}
    arch[17] = {'in_channels': [input_nc] + [ch * item for item in [1, 4]],
                'out_channels': [item * ch for item in [1, 4, 8]],
                'downsample': [True] * 3,
                'resolution': [16, 8, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 5)}}
    arch[20] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
                'out_channels': [item * ch for item in [1, 8, 16, 16]],
                'downsample': [True] * 3 + [False],
                'resolution': [16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 5)}}
    return arch
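

# Reading one entry (sketch): for 16-pixel-high inputs with ch=64 and a
# single-channel image,
#
#     d = D_arch(ch=64, attention='64', input_nc=1)[16]
#     d['in_channels']   # [1, 64, 512, 1024]
#     d['out_channels']  # [64, 512, 1024, 1024]
#     d['downsample']    # [True, True, True, False]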


class Discriminator(nn.Module):

    def __init__(self, D_ch=64, D_wide=True, resolution=resolution,
                 D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 SN_eps=1e-8, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0], bn_linear='SN',
                 input_nc=1, one_hot=False, **kwargs):
        super(Discriminator, self).__init__()
        self.name = 'D'
        self.gpu_ids = gpu_ids
        self.one_hot = one_hot
        # Channel width multiplier
        self.ch = D_ch
        # Use a wide D, as in BigGAN, or a skinny one, as in SN-GAN?
        self.D_wide = D_wide
        # Input resolution; selects an entry of D_arch
        self.resolution = resolution
        self.kernel_size = D_kernel_size
        self.attention = D_attn
        self.n_classes = n_classes
        self.activation = D_activation
        self.init = D_init
        self.D_param = D_param
        self.SN_eps = SN_eps
        self.fp16 = D_fp16
        self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]

        # Which convs, linears, and embeddings to use
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
            if bn_linear == 'SN':
                self.which_embedding = functools.partial(layers.SNLinear,
                                                         num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                         eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
            self.which_embedding = nn.Embedding
        if one_hot:
            self.which_embedding = functools.partial(layers.SNLinear,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)

        # Residual discriminator blocks
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Linear output layer, plus a class embedding sized to the pooled features
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        if not skip_init:
            self = init_weights(self, D_init)

    def forward(self, x, y=None, **kwargs):
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling over the spatial dimensions
        h = torch.sum(self.activation(h), [2, 3])
        # Class-unconditional output
        out = self.linear(h)
        # Project the pooled features onto the class embedding and add to evidence
        if y is not None:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out
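
# The class-conditional term above is the projection discriminator of Miyato &
# Koyama ("cGANs with Projection Discriminator"): the score is linear(h) plus
# the inner product <embed(y), h>. A minimal numeric sketch of that term:
#
#     h = torch.randn(4, 1024)                  # pooled features for B=4
#     e = torch.randn(4, 1024)                  # self.embed(y) for the 4 labels
#     proj = torch.sum(e * h, 1, keepdim=True)  # (4, 1) class-conditional evidence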

    def return_features(self, x, y=None):
        h = x
        block_output = []
        # Collect the intermediate activations of every block
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
                block_output.append(h)
        return block_output


class WDiscriminator(nn.Module):

    def __init__(self, D_ch=64, D_wide=True, resolution=resolution,
                 D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 SN_eps=1e-8, output_dim=NUM_WRITERS, D_mixed_precision=False, D_fp16=False,
                 D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0], bn_linear='SN',
                 input_nc=1, one_hot=False, **kwargs):
        super(WDiscriminator, self).__init__()
        self.name = 'D'
        self.gpu_ids = gpu_ids
        self.one_hot = one_hot
        self.ch = D_ch
        self.D_wide = D_wide
        self.resolution = resolution
        self.kernel_size = D_kernel_size
        self.attention = D_attn
        self.n_classes = n_classes
        self.activation = D_activation
        self.init = D_init
        self.D_param = D_param
        self.SN_eps = SN_eps
        self.fp16 = D_fp16
        self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]

        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
            if bn_linear == 'SN':
                self.which_embedding = functools.partial(layers.SNLinear,
                                                         num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                         eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
            self.which_embedding = nn.Embedding
        if one_hot:
            self.which_embedding = functools.partial(layers.SNLinear,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)

        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Writer classifier head: output_dim is the number of writers
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
        self.cross_entropy = nn.CrossEntropyLoss()

        if not skip_init:
            self = init_weights(self, D_init)

    def forward(self, x, y=None, **kwargs):
        h = x
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        h = torch.sum(self.activation(h), [2, 3])
        out = self.linear(h)
        # Unlike Discriminator, this head returns the writer-classification
        # cross-entropy loss against the writer labels y, not a raw score
        loss = self.cross_entropy(out, y.long())
        return loss

    def return_features(self, x, y=None):
        h = x
        block_output = []
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
                block_output.append(h)
        return block_output


class Encoder(Discriminator):
    def __init__(self, opt, output_dim, **kwargs):
        super(Encoder, self).__init__(**vars(opt))
        # Replace the classifier head with a conv that emits output_dim features
        self.output_layer = nn.Sequential(self.activation,
                                          nn.Conv2d(self.arch['out_channels'][-1], output_dim,
                                                    kernel_size=(4, 2), padding=0, stride=2))

    def forward(self, x):
        h = x
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        out = self.output_layer(h)
        return out


class BiDiscriminator(nn.Module):
    def __init__(self, opt):
        super(BiDiscriminator, self).__init__()
        self.infer_img = Encoder(opt, output_dim=opt.nimg_features)

        # Joint inference over image features and the latent code, via 1x1 convs
        self.infer_joint = nn.Sequential(
            nn.Conv2d(opt.dim_z + opt.nimg_features, 1024, 1, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 1, stride=1, bias=True),
            nn.ReLU(inplace=True)
        )
        self.final = nn.Conv2d(1024, 1, 1, stride=1, bias=True)

    def forward(self, x, z, **kwargs):
        output_x = self.infer_img(x)
        # Broadcast a (B, Z) latent along the width of the image features
        if len(z.shape) == 2:
            z = z.unsqueeze(2).unsqueeze(2).repeat((1, 1, 1, output_x.shape[3]))
        output_features = self.infer_joint(torch.cat([output_x, z], dim=1))
        output = self.final(output_features)
        return output
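
# Shape sketch (assumption: the Encoder emits a height-1 feature map, i.e.
# output_x of shape (B, nimg_features, 1, W), so the channel-wise concat works):
#
#     z4 = z.unsqueeze(2).unsqueeze(2).repeat((1, 1, 1, W))  # (B, dim_z, 1, W)
#     joint = torch.cat([output_x, z4], dim=1)               # (B, dim_z + nimg_features, 1, W)
#     # infer_joint and final are all 1x1 convs -> (B, 1, 1, W) per-column scores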


class G_D(nn.Module):
    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D

    def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
                split_D=False):
        # If training G, enable grad tracking through G's forward pass
        with torch.set_grad_enabled(train_G):
            # NOTE: Generator.forward above takes (x, y_ind, y); this wrapper
            # keeps the original BigGAN calling convention and may be unused here
            G_z = self.G(z, self.G.shared(gy))
            # Cast as necessary when G and D run at different precisions
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()

        # split_D means to run D once with real data and once with fake,
        # rather than concatenating along the batch dimension
        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                D_real = self.D(x, dy)
                return D_fake, D_real
            else:
                if return_G_z:
                    return D_fake, G_z
                else:
                    return D_fake
        # Otherwise, if real data is provided, concatenate it with the
        # generator's output along the batch dimension for a single D pass
        else:
            D_input = torch.cat([G_z, x], 0) if x is not None else G_z
            D_class = torch.cat([gy, dy], 0) if dy is not None else gy
            D_out = self.D(D_input, D_class)
            if x is not None:
                # Split the output back into fake and real slices
                return torch.split(D_out, [G_z.shape[0], x.shape[0]])
            else:
                if return_G_z:
                    return D_out, G_z
                else:
                    return D_out
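

# Minimal usage sketch (assumptions: G and D are built for the same resolution,
# gy/dy are class-index tensors, and G.forward accepts this (z, embedding)
# calling convention):
#
#     gan = G_D(G, D)
#     z = torch.randn(B, G.dim_z)
#     D_fake, D_real = gan(z, gy, x=real_images, dy=dy)  # joint fake/real pass
#     D_fake = gan(z, gy, train_G=True)                  # generator update pass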