Spaces:
Running
on
T4
Running
on
T4
# temporal patch GAN to maintain the temporal consecutive of the flows | |
import torch | |
import torch.nn as nn | |
from .BaseNetwork import BaseNetwork | |
class Discriminator(BaseNetwork): | |
def __init__(self, in_channels, conv_type, dist_cnum, use_sigmoid=False, use_spectral_norm=True, init_weights=True): | |
""" | |
Args: | |
in_channels: The input channels of the discriminator | |
use_sigmoid: Whether to use sigmoid for the base network (true for the nsgan) | |
use_spectral_norm: The usage of the spectral norm: always be true for the stability of GAN | |
init_weights: always be True | |
""" | |
super(Discriminator, self).__init__(conv_type) | |
self.use_sigmoid = use_sigmoid | |
nf = dist_cnum | |
self.conv = nn.Sequential( | |
spectral_norm( | |
nn.Conv3d(in_channels=in_channels, out_channels=nf * 1, kernel_size=(3, 5, 5), stride=(1, 2, 2), | |
padding=(1, 2, 2), | |
bias=not use_spectral_norm), use_spectral_norm), | |
nn.LeakyReLU(0.2, inplace=True), | |
spectral_norm( | |
nn.Conv3d(in_channels=nf * 1, out_channels=nf * 2, kernel_size=(3, 5, 5), stride=(1, 2, 2), | |
padding=(1, 2, 2), | |
bias=not use_spectral_norm), use_spectral_norm), | |
nn.LeakyReLU(0.2, inplace=True), | |
spectral_norm( | |
nn.Conv3d(in_channels=nf * 2, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), | |
padding=(1, 2, 2), | |
bias=not use_spectral_norm), use_spectral_norm), | |
nn.LeakyReLU(0.2, inplace=True), | |
spectral_norm( | |
nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), | |
padding=(1, 2, 2), | |
bias=not use_spectral_norm), use_spectral_norm), | |
nn.LeakyReLU(0.2, inplace=True), | |
spectral_norm( | |
nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), | |
padding=(1, 2, 2), | |
bias=not use_spectral_norm), use_spectral_norm), | |
nn.LeakyReLU(0.2, inplace=True), | |
nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), | |
padding=(1, 2, 2)) | |
) | |
if init_weights: | |
self.init_weights() | |
def forward(self, xs, t): | |
""" | |
Args: | |
xs: Input feature, with shape of [bt, c, h, w] | |
Returns: The discriminative map from the GAN | |
""" | |
bt, c, h, w = xs.shape | |
b = bt // t | |
xs = xs.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() | |
feat = self.conv(xs) | |
if self.use_sigmoid: | |
feat = torch.sigmoid(feat) | |
out = torch.transpose(feat, 1, 2) # [b, t, c, h, w] | |
return out | |
def spectral_norm(module, mode=True): | |
if mode: | |
return nn.utils.spectral_norm(module) | |
return module | |