Spaces:
Running
Running
import torch | |
from modules.base import BaseModule | |
from modules.linear_modulation import FeatureWiseAffine | |
from modules.interpolation import InterpolationBlock | |
from modules.layers import Conv1dWithInitialization | |
class BasicModulationBlock(BaseModule):
    """
    Linear modulation part of UBlock. Applies the following layers in order:
        - Feature-wise Affine (driven by the `scale` and `shift` inputs)
        - LReLU (negative slope 0.2)
        - 3x1 Conv (dilated, same number of input and output channels)
    """

    def __init__(self, n_channels, dilation):
        """
        :param n_channels: channel count used for both the input and the
            output of the convolution.
        :param dilation: dilation of the kernel-size-3 convolution; the same
            value is also passed as its padding.
            NOTE(review): with kernel size 3 this presumably keeps the
            temporal length unchanged -- confirm against
            Conv1dWithInitialization.
        """
        super().__init__()
        self.featurewise_affine = FeatureWiseAffine()
        self.leaky_relu = torch.nn.LeakyReLU(0.2)

        conv_kwargs = dict(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )
        self.convolution = Conv1dWithInitialization(**conv_kwargs)

    def forward(self, x, scale, shift):
        """
        Modulate `x` with `scale`/`shift`, apply the activation, then convolve.

        :param x: input features handed to the feature-wise affine layer.
        :param scale: forwarded unchanged to the feature-wise affine layer.
        :param shift: forwarded unchanged to the feature-wise affine layer.
        :return: output of the final convolution.
        """
        modulated = self.featurewise_affine(x, scale, shift)
        activated = self.leaky_relu(modulated)
        return self.convolution(activated)
class UpsamplingBlock(BaseModule):
    """
    Upsampling block made of two consecutive residual stages:
        - first stage: main branch (LReLU -> interpolation -> 3x1 Conv -> LReLU
          -> modulation block) summed with a residual branch
          (interpolation -> 1x1 Conv) computed from the same input
        - second stage: two chained modulation blocks whose result is added
          back to the first stage output
    The same `scale` and `shift` are fed to every modulation block.
    """

    def __init__(self, in_channels, out_channels, factor, dilations):
        """
        :param in_channels: channel count of the block input.
        :param out_channels: channel count produced by both branches of the
            first stage and kept through the second stage.
        :param factor: passed as `scale_factor` to both interpolation blocks.
        :param dilations: indexable with at least four entries. Entry 0 is
            used by the upsampling convolution, entry 1 by the first-stage
            modulation block, entries 2 and 3 by the second-stage
            modulation blocks.
        """
        super().__init__()

        # First stage, main branch. Sub-modules are created in the same order
        # as they are registered.
        upsampling = torch.nn.Sequential(
            torch.nn.LeakyReLU(0.2),
            InterpolationBlock(
                scale_factor=factor,
                mode='linear',
                align_corners=False,
            ),
            Conv1dWithInitialization(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=dilations[0],
                dilation=dilations[0],
            ),
            torch.nn.LeakyReLU(0.2),
        )
        modulation = BasicModulationBlock(out_channels, dilation=dilations[1])
        self.first_block_main_branch = torch.nn.ModuleDict({
            'upsampling': upsampling,
            'modulation': modulation,
        })

        # First stage, residual branch: resample the input, then project it
        # to `out_channels` with a kernel-size-1 convolution.
        residual_resample = InterpolationBlock(
            scale_factor=factor,
            mode='linear',
            align_corners=False,
        )
        residual_projection = Conv1dWithInitialization(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
        )
        self.first_block_residual_branch = torch.nn.Sequential(
            residual_resample,
            residual_projection,
        )

        # Second stage: two modulation blocks using dilations[2] and
        # dilations[3].
        second_stage = torch.nn.ModuleDict()
        for idx in range(2):
            second_stage[f'modulation_{idx}'] = BasicModulationBlock(
                out_channels, dilation=dilations[2 + idx]
            )
        self.second_block_main_branch = second_stage

    def forward(self, x, scale, shift):
        """
        Run both residual stages on `x`.

        :param x: block input, consumed by both branches of the first stage.
        :param scale: forwarded to every modulation block.
        :param shift: forwarded to every modulation block.
        :return: sum of the first-stage output and the second-stage branch.
        """
        # First stage: main branch plus residual branch.
        first_main = self.first_block_main_branch
        upsampled = first_main['upsampling'](x)
        modulated = first_main['modulation'](upsampled, scale, shift)
        first_stage_out = modulated + self.first_block_residual_branch(x)

        # Second stage: two chained modulation blocks added back on top.
        second_main = self.second_block_main_branch
        hidden = second_main['modulation_0'](first_stage_out, scale, shift)
        return first_stage_out + second_main['modulation_1'](hidden, scale, shift)