"""Building blocks for dilated residual 1-D convolutional encoders/decoders."""

from typing import Callable, List, Sequence, Type, Union

import numpy as np
import torch
import torch.nn as nn

# Anything that returns a fresh ``nn.Module`` when called with no arguments
# (a module class such as ``nn.ReLU`` or a zero-argument factory).
ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]


def _channel_schedule(capacity: int, num_stages: int) -> List[int]:
    """Return ``[capacity * 2**i for i in range(num_stages + 1)]``.

    The values are converted to built-in Python numbers so that the module
    factories receiving them never see NumPy scalar types.
    """
    return (capacity * 2**np.arange(num_stages + 1)).tolist()


def _normalize_dilations(
        dilations: Union[Sequence[int], Sequence[Sequence[int]]],
        ratios: Sequence[int]):
    """Expand a flat dilation list into one dilation list per ratio.

    A flat sequence (first element is an integer, Python or NumPy) or an
    empty sequence is repeated once per entry of ``ratios``. Anything else is
    assumed to already be one sequence per stage and is returned unchanged.
    """
    is_flat = (len(dilations) == 0
               or isinstance(dilations[0], (int, np.integer)))
    if is_flat:
        return [dilations for _ in ratios]
    return dilations


class FeedForwardModule(nn.Module):
    """Base class whose forward pass simply applies ``self.net``.

    Subclasses are expected to assign ``self.net`` in their constructor. It
    is ``None`` until they do, so calling ``forward`` before that fails with
    a ``TypeError`` (``None`` is not callable).
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.net(x)``."""
        return self.net(x)


class Residual(nn.Module):
    """Wrap a module with a skip connection: ``module(x) + x``.

    The wrapped module's output must be addable to its input.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped module's output added to its input."""
        return self.module(x) + x


class DilatedConvolutionalUnit(FeedForwardModule):
    """Activation -> dilated ``Conv1d`` -> activation -> 1x1 ``Conv1d``.

    Both convolutions keep ``hidden_dim`` channels. The dilated convolution
    uses ``padding = ((kernel_size - 1) * dilation) // 2``; per the
    ``Conv1d`` output-length formula this preserves the temporal length
    whenever ``(kernel_size - 1) * dilation`` is even (e.g. any odd kernel).

    Args:
        hidden_dim: Number of input and output channels.
        dilation: Dilation of the first convolution.
        kernel_size: Kernel size of the first convolution.
        activation: Zero-argument factory; called twice, once per activation.
        normalization: Wrapper applied to the dilated convolution only
            (identity by default).
    """

    def __init__(
            self,
            hidden_dim: int,
            dilation: int,
            kernel_size: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        dilated_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=((kernel_size - 1) * dilation) // 2,
        )
        pointwise_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=1,
        )
        self.net = nn.Sequential(
            activation(),
            normalization(dilated_conv),
            activation(),
            pointwise_conv,
        )


class UpsamplingUnit(FeedForwardModule):
    """Activation followed by a strided ``ConvTranspose1d``.

    The transposed convolution uses ``kernel_size = 2 * stride``,
    ``padding = stride // 2 + stride % 2`` and an ``output_padding`` of 1 for
    odd strides. Plugging these into the ``ConvTranspose1d`` output-length
    formula gives an output of length ``L * stride`` for an input of
    length ``L``.

    NOTE(review): for ``stride == 1`` this requests ``output_padding == 1``,
    which PyTorch appears to reject at call time (output padding must be
    smaller than stride or dilation) -- confirm stride 1 is never used.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        stride: Upsampling factor.
        activation: Zero-argument factory for the leading activation.
        normalization: Wrapper applied to the transposed convolution
            (identity by default).
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        is_odd = stride % 2 != 0
        transposed_conv = nn.ConvTranspose1d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=stride // 2 + stride % 2,
            # Odd strides need one extra output sample to reach L * stride.
            output_padding=1 if is_odd else 0,
        )
        self.net = nn.Sequential(
            activation(),
            normalization(transposed_conv),
        )


class DownsamplingUnit(FeedForwardModule):
    """Activation followed by a strided ``Conv1d``.

    The convolution uses ``kernel_size = 2 * stride`` and
    ``padding = stride // 2 + stride % 2``. By the ``Conv1d`` output-length
    formula, for ``stride >= 2`` and an input length ``L`` that is a multiple
    of ``stride`` the output length is ``L // stride``.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        stride: Downsampling factor.
        activation: Zero-argument factory for the leading activation.
        normalization: Wrapper applied to the convolution (identity by
            default).
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        strided_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=stride // 2 + stride % 2,
        )
        self.net = nn.Sequential(
            activation(),
            normalization(strided_conv),
        )


class DilatedResidualEncoder(FeedForwardModule):
    """Stack of residual dilated units interleaved with downsampling units.

    Channel widths follow ``capacity * 2**i`` for ``i = 0 .. len(ratios)``.
    The network is: normalized ``pre_network_conv``, then for every stage the
    residual dilated units (at the stage's input width) followed by one
    downsampling unit, and finally ``post_network_conv`` (not normalized).

    Stages are formed by zipping ``ratios`` with the per-stage dilations, so
    surplus entries in either are ignored. With empty ``ratios`` the network
    consists of the pre and post convolutions only.

    Args:
        capacity: Channel width of the first stage.
        dilated_unit: Factory invoked as ``dilated_unit(dim, dilation)``.
        downsampling_unit: Factory invoked as
            ``downsampling_unit(input_dim, output_dim, ratio)``.
        ratios: Downsampling ratio of each stage.
        dilations: Either one flat list of dilations shared by all stages or
            one list of dilations per stage.
        pre_network_conv: Factory invoked with the keyword ``out_channels``.
        post_network_conv: Factory invoked with the keyword ``in_channels``.
        normalization: Wrapper applied to the pre-network convolution
            (identity by default).
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            downsampling_unit: Type[DownsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = _channel_schedule(capacity, len(ratios))
        dilations_list = self.normalize_dilations(dilations, ratios)

        layers = [normalization(pre_network_conv(out_channels=channels[0]))]

        # Width fed to the post-network conv; stays at channels[0] when no
        # stage is built (e.g. empty ``ratios``).
        last_dim = channels[0]
        stages = zip(ratios, dilations_list, channels[:-1], channels[1:])
        for ratio, stage_dilations, input_dim, output_dim in stages:
            layers.extend(
                Residual(dilated_unit(input_dim, dilation))
                for dilation in stage_dilations)
            layers.append(downsampling_unit(input_dim, output_dim, ratio))
            last_dim = output_dim

        layers.append(post_network_conv(in_channels=last_dim))

        self.net = nn.Sequential(*layers)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Return one dilation sequence per entry of ``ratios``.

        A flat (or empty) ``dilations`` is repeated for every ratio; a nested
        one is returned unchanged.
        """
        return _normalize_dilations(dilations, ratios)


class DilatedResidualDecoder(FeedForwardModule):
    """Stack of upsampling units each followed by residual dilated units.

    Channel widths follow ``capacity * 2**i`` for ``i = 0 .. len(ratios)``,
    traversed from widest to narrowest. The network is: ``pre_network_conv``
    (not normalized), then for every stage one upsampling unit followed by
    the residual dilated units (at the stage's output width), and finally the
    normalized ``post_network_conv``.

    The per-stage dilation lists are consumed in reverse order, whereas
    ``ratios`` is consumed in the order given. Stages are formed by zipping
    the two, so surplus entries in either are ignored. With empty ``ratios``
    the network consists of the pre and post convolutions only.

    Args:
        capacity: Channel width of the last (narrowest) stage.
        dilated_unit: Factory invoked as ``dilated_unit(dim, dilation)``.
        upsampling_unit: Factory invoked as
            ``upsampling_unit(input_dim, output_dim, ratio)``.
        ratios: Upsampling ratio of each stage.
        dilations: Either one flat list of dilations shared by all stages or
            one list of dilations per stage.
        pre_network_conv: Factory invoked with the keyword ``out_channels``.
        post_network_conv: Factory invoked with the keyword ``in_channels``.
        normalization: Wrapper applied to the post-network convolution
            (identity by default).
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            upsampling_unit: Type[UpsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = _channel_schedule(capacity, len(ratios))[::-1]
        dilations_list = self.normalize_dilations(dilations, ratios)[::-1]

        layers = [pre_network_conv(out_channels=channels[0])]

        # Width fed to the post-network conv; stays at channels[0] when no
        # stage is built (e.g. empty ``ratios``).
        last_dim = channels[0]
        stages = zip(ratios, dilations_list, channels[:-1], channels[1:])
        for ratio, stage_dilations, input_dim, output_dim in stages:
            layers.append(upsampling_unit(input_dim, output_dim, ratio))
            layers.extend(
                Residual(dilated_unit(output_dim, dilation))
                for dilation in stage_dilations)
            last_dim = output_dim

        layers.append(normalization(post_network_conv(in_channels=last_dim)))

        self.net = nn.Sequential(*layers)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Return one dilation sequence per entry of ``ratios``.

        A flat (or empty) ``dilations`` is repeated for every ratio; a nested
        one is returned unchanged.
        """
        return _normalize_dilations(dilations, ratios)