# xcodec2/vq/blocks.py
# Author: yezhen — "Initial commit" (574a515)
from typing import Callable, Sequence, Type, Union
import numpy as np
import torch
import torch.nn as nn
# Something that, called with no arguments, yields a new nn.Module: either an
# nn.Module subclass itself (e.g. nn.ReLU) or a zero-argument factory such as
# a lambda. The unit classes in this file call their ``activation`` factory
# once per activation layer they build.
ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]
class FeedForwardModule(nn.Module):
    """Base class for blocks whose forward pass is a single ``self.net`` call.

    Subclasses call ``super().__init__()`` and then assign ``self.net``
    (typically an ``nn.Sequential``); ``forward`` simply applies it.
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder; replaced by the subclass with the real network module.
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.net`` to ``x`` and return the result.

        Raises:
            NotImplementedError: if the subclass never assigned ``self.net``
                (previously this surfaced as an opaque "'NoneType' object is
                not callable" TypeError).
        """
        if self.net is None:
            raise NotImplementedError(
                f"{type(self).__name__} must assign self.net before forward()")
        return self.net(x)
class Residual(nn.Module):
    """Add an identity skip connection around ``module``.

    The output is ``module(x) + x``, so ``module`` must preserve the shape of
    its input (or produce something broadcastable against it).
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        # Registered under the name "module", so it appears in state_dict keys.
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        transformed = self.module(x)
        return transformed + skip
class DilatedConvolutionalUnit(FeedForwardModule):
    """Activation -> normalized dilated conv -> activation -> 1x1 conv.

    The channel count stays ``hidden_dim`` throughout. The dilated convolution
    is padded by ``((kernel_size - 1) * dilation) // 2``, which keeps the time
    length unchanged whenever ``(kernel_size - 1) * dilation`` is even.

    Args:
        hidden_dim: Input and output channel count of both convolutions.
        dilation: Dilation of the first (wide) convolution.
        kernel_size: Kernel size of the first convolution.
        activation: Zero-argument factory; called twice, once per activation.
        normalization: Wrapper applied to the dilated convolution only
            (identity by default).
    """

    def __init__(
        self,
        hidden_dim: int,
        dilation: int,
        kernel_size: int,
        activation: ModuleFactory,
        normalization: Callable[[nn.Module],
                                nn.Module] = lambda x: x) -> None:
        super().__init__()
        receptive_span = (kernel_size - 1) * dilation
        # Build in the original order so any RNG-driven init is unchanged.
        pre_act = activation()
        dilated_conv = normalization(
            nn.Conv1d(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=receptive_span // 2,
            ))
        mid_act = activation()
        pointwise_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=1,
        )
        self.net = nn.Sequential(pre_act, dilated_conv, mid_act, pointwise_conv)
class UpsamplingUnit(FeedForwardModule):
    """Activation followed by a normalized transposed conv upsampling by ``stride``.

    The transposed convolution uses ``kernel_size = 2 * stride``. Its
    ``padding`` and ``output_padding`` depend on whether ``stride`` is odd.
    By the ConvTranspose1d length formula, the output length is exactly
    ``L * stride`` for an input of length ``L`` (``stride > 1``).

    NOTE(review): for ``stride == 1`` this sets ``output_padding=1``. PyTorch
    requires output_padding to be smaller than stride or dilation, so that
    configuration looks unsupported. Confirm only strides > 1 are used.

    Args:
        input_dim: Input channel count.
        output_dim: Output channel count.
        stride: Upsampling factor.
        activation: Zero-argument factory for the leading activation.
        normalization: Wrapper applied to the transposed convolution
            (identity by default).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        stride: int,
        activation: ModuleFactory,
        normalization: Callable[[nn.Module],
                                nn.Module] = lambda x: x) -> None:
        super().__init__()
        odd = stride % 2  # 1 for odd strides, 0 for even ones
        leading_act = activation()
        upsample = normalization(
            nn.ConvTranspose1d(
                in_channels=input_dim,
                out_channels=output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + odd,
                output_padding=odd,
            ))
        self.net = nn.Sequential(leading_act, upsample)
class DownsamplingUnit(FeedForwardModule):
    """Activation followed by a normalized strided conv downsampling by ``stride``.

    The convolution uses ``kernel_size = 2 * stride`` and padding
    ``stride // 2 + stride % 2``. For ``stride > 1`` and an input length
    ``L`` that is a multiple of ``stride``, the output length is
    ``L / stride`` (per the Conv1d length formula).

    Args:
        input_dim: Input channel count.
        output_dim: Output channel count.
        stride: Downsampling factor.
        activation: Zero-argument factory for the leading activation.
        normalization: Wrapper applied to the convolution (identity by
            default).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        stride: int,
        activation: ModuleFactory,
        normalization: Callable[[nn.Module],
                                nn.Module] = lambda x: x) -> None:
        super().__init__()
        pad = stride // 2 + stride % 2
        leading_act = activation()
        downsample = normalization(
            nn.Conv1d(
                in_channels=input_dim,
                out_channels=output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=pad,
            ))
        self.net = nn.Sequential(leading_act, downsample)
class DilatedResidualEncoder(FeedForwardModule):
    """Strided convolutional encoder built from residual dilated stages.

    Layout: normalized pre-network conv, then for each entry of ``ratios`` a
    stack of ``Residual(dilated_unit(...))`` followed by one downsampling
    unit, then a post-network conv. Channel width starts at ``capacity`` and
    doubles at every downsampling stage, so stage ``i`` maps
    ``capacity * 2**i`` to ``capacity * 2**(i + 1)`` channels.

    Args:
        capacity: Channel count produced by the pre-network convolution.
        dilated_unit: Called as ``dilated_unit(channels, dilation)``; each
            result is wrapped in ``Residual``.
        downsampling_unit: Called as
            ``downsampling_unit(input_dim, output_dim, ratio)``.
        ratios: Downsampling stride of each stage, in order. May be empty,
            in which case only the pre/post convs are built.
        dilations: Either one flat sequence of dilations reused by every
            stage, or one sequence of dilations per stage.
        pre_network_conv: Called with ``out_channels=`` only.
        post_network_conv: Called with ``in_channels=`` only.
        normalization: Wrapper applied to the pre-network convolution
            (identity by default).
    """

    def __init__(
        self,
        capacity: int,
        dilated_unit: Type[DilatedConvolutionalUnit],
        downsampling_unit: Type[DownsamplingUnit],
        ratios: Sequence[int],
        dilations: Union[Sequence[int], Sequence[Sequence[int]]],
        pre_network_conv: Type[nn.Conv1d],
        post_network_conv: Type[nn.Conv1d],
        normalization: Callable[[nn.Module],
                                nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Plain Python ints (same values as capacity * 2**np.arange(...)) so
        # factories never receive numpy scalar types.
        channels = [capacity * 2**i for i in range(len(ratios) + 1)]
        dilations_list = self.normalize_dilations(dilations, ratios)

        net = [normalization(pre_network_conv(out_channels=channels[0]))]
        # Width entering the post-network conv. Seeded with the pre-conv width
        # so an empty ``ratios`` no longer hits an unbound ``output_dim``.
        last_dim = channels[0]
        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(input_dim, dilation)))
            net.append(downsampling_unit(input_dim, output_dim, ratio))
            last_dim = output_dim
        net.append(post_network_conv(in_channels=last_dim))
        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Return one dilation sequence per stage.

        A flat sequence (empty, or whose first element is a Python ``int`` or
        NumPy integer) is repeated once per entry of ``ratios``. Anything else
        is assumed to already be per-stage and is returned unchanged.
        """
        if len(dilations) == 0 or isinstance(dilations[0], (int, np.integer)):
            dilations = [dilations for _ in ratios]
        return dilations
class DilatedResidualDecoder(FeedForwardModule):
    """Transposed-conv decoder built from upsampling and residual dilated stages.

    Layout: pre-network conv, then for each entry of ``ratios`` one upsampling
    unit followed by a stack of ``Residual(dilated_unit(...))``, then a
    normalized post-network conv. Channel width starts at
    ``capacity * 2**len(ratios)`` and halves at every upsampling stage,
    ending at ``capacity``.

    Args:
        capacity: Channel count fed to the post-network convolution.
        dilated_unit: Called as ``dilated_unit(channels, dilation)``; each
            result is wrapped in ``Residual``.
        upsampling_unit: Called as
            ``upsampling_unit(input_dim, output_dim, ratio)``.
        ratios: Upsampling stride of each stage, applied in the given order.
            May be empty, in which case only the pre/post convs are built.
        dilations: Either one flat sequence reused by every stage, or one
            sequence per stage. Per-stage sequences are consumed in reverse
            order (the last entry is used by the first stage).
        pre_network_conv: Called with ``out_channels=`` only.
        post_network_conv: Called with ``in_channels=`` only.
        normalization: Wrapper applied to the post-network convolution
            (identity by default).
    """

    def __init__(
        self,
        capacity: int,
        dilated_unit: Type[DilatedConvolutionalUnit],
        upsampling_unit: Type[UpsamplingUnit],
        ratios: Sequence[int],
        dilations: Union[Sequence[int], Sequence[Sequence[int]]],
        pre_network_conv: Type[nn.Conv1d],
        post_network_conv: Type[nn.Conv1d],
        normalization: Callable[[nn.Module],
                                nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Widest first: plain Python ints, same values as the numpy version.
        channels = [capacity * 2**i for i in range(len(ratios) + 1)][::-1]
        dilations_list = self.normalize_dilations(dilations, ratios)[::-1]

        net = [pre_network_conv(out_channels=channels[0])]
        # Width entering the post-network conv. Seeded with the pre-conv width
        # so an empty ``ratios`` no longer hits an unbound ``output_dim``.
        last_dim = channels[0]
        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            net.append(upsampling_unit(input_dim, output_dim, ratio))
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(output_dim, dilation)))
            last_dim = output_dim
        net.append(normalization(post_network_conv(in_channels=last_dim)))
        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Return one dilation sequence per stage.

        A flat sequence (empty, or whose first element is a Python ``int`` or
        NumPy integer) is repeated once per entry of ``ratios``. Anything else
        is assumed to already be per-stage and is returned unchanged.
        """
        if len(dilations) == 0 or isinstance(dilations[0], (int, np.integer)):
            dilations = [dilations for _ in ratios]
        return dilations