import torch from modules.base import BaseModule from modules.interpolation import InterpolationBlock from modules.layers import Conv1dWithInitialization class ConvolutionBlock(BaseModule): def __init__(self, in_channels, out_channels, dilation): super(ConvolutionBlock, self).__init__() self.leaky_relu = torch.nn.LeakyReLU(0.2) self.convolution = Conv1dWithInitialization( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation ) def forward(self, x): outputs = self.leaky_relu(x) outputs = self.convolution(outputs) return outputs class DownsamplingBlock(BaseModule): def __init__(self, in_channels, out_channels, factor, dilations): super(DownsamplingBlock, self).__init__() in_sizes = [in_channels] + [out_channels for _ in range(len(dilations) - 1)] out_sizes = [out_channels for _ in range(len(in_sizes))] self.main_branch = torch.nn.Sequential(*([ InterpolationBlock( scale_factor=factor, mode='linear', align_corners=False, downsample=True ) ] + [ ConvolutionBlock(in_size, out_size, dilation) for in_size, out_size, dilation in zip(in_sizes, out_sizes, dilations) ])) self.residual_branch = torch.nn.Sequential(*[ Conv1dWithInitialization( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 ), InterpolationBlock( scale_factor=factor, mode='linear', align_corners=False, downsample=True ) ]) def forward(self, x): outputs = self.main_branch(x) outputs = outputs + self.residual_branch(x) return outputs