# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""MelGAN Modules."""

import logging

import numpy as np
import torch

from modules.base import BaseModule
class MelGANDiscriminator(BaseModule):
    """MelGAN discriminator module."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_sizes=[5, 3],
        channels=16,
        max_downsample_channels=1024,
        bias=True,
        downsample_scales=[4, 4, 4, 4],
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        pad="ReflectionPad1d",
        pad_params={},
    ):
"""Initilize MelGAN discriminator module. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer, | |
and the first and the second kernel sizes will be used for the last two layers. | |
For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, | |
the last two layers' kernel size will be 5 and 3, respectively. | |
channels (int): Initial number of channels for conv layer. | |
max_downsample_channels (int): Maximum number of channels for downsampling layers. | |
bias (bool): Whether to add bias parameter in convolution layers. | |
downsample_scales (list): List of downsampling scales. | |
nonlinear_activation (str): Activation function module name. | |
nonlinear_activation_params (dict): Hyperparameters for activation function. | |
pad (str): Padding function module name before dilated convolution layer. | |
pad_params (dict): Hyperparameters for padding function. | |
""" | |
        super(MelGANDiscriminator, self).__init__()
        self.layers = torch.nn.ModuleList()

        # check kernel size is valid
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1

        # add first layer
        self.layers += [
            torch.nn.Sequential(
                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
                torch.nn.Conv1d(
                    in_channels, channels, np.prod(kernel_sizes), bias=bias
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        # add downsample layers
        in_chs = channels
        for downsample_scale in downsample_scales:
            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
            self.layers += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chs,
                        out_chs,
                        kernel_size=downsample_scale * 10 + 1,
                        stride=downsample_scale,
                        padding=downsample_scale * 5,
                        groups=in_chs // 4,
                        bias=bias,
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
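        # NOTE (illustrative): with the defaults, the channel counts grow as
        # 16 -> 64 -> 256 -> 1024 -> 1024 (capped by max_downsample_channels),
        # while each grouped strided conv shortens the sequence by its scale.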
        # add final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_chs,
                    out_chs,
                    kernel_sizes[0],
                    padding=(kernel_sizes[0] - 1) // 2,
                    bias=bias,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        self.layers += [
            torch.nn.Conv1d(
                out_chs,
                out_channels,
                kernel_sizes[1],
                padding=(kernel_sizes[1] - 1) // 2,
                bias=bias,
            ),
        ]
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: List of output tensors of each layer.

        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs += [x]

        return outs
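

# Usage sketch (not part of the module; shapes assume the default arguments):
#
#     >>> d = MelGANDiscriminator()
#     >>> wav = torch.randn(2, 1, 16384)  # (B, 1, T) dummy batch
#     >>> feats = d(wav)
#     >>> len(feats)  # first conv + 4 downsample layers + 2 final layers
#     7
#     >>> feats[-1].shape  # T is reduced by prod(downsample_scales) = 256
#     torch.Size([2, 1, 64])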
class MelGANMultiScaleDiscriminator(BaseModule):
    """MelGAN multi-scale discriminator module."""

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        scales=3,
        downsample_pooling="AvgPool1d",
        # follow the official implementation setting
        downsample_pooling_params={
            "kernel_size": 4,
            "stride": 2,
            "padding": 1,
            "count_include_pad": False,
        },
        kernel_sizes=[5, 3],
        channels=16,
        max_downsample_channels=1024,
        bias=True,
        downsample_scales=[4, 4, 4, 4],
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        pad="ReflectionPad1d",
        pad_params={},
        use_weight_norm=True,
    ):
"""Initilize MelGAN multi-scale discriminator module. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
scales (int): Number of multi-scales. | |
downsample_pooling (str): Pooling module name for downsampling of the inputs. | |
downsample_pooling_params (dict): Parameters for the above pooling module. | |
kernel_sizes (list): List of two kernel sizes. The sum will be used for the first conv layer, | |
and the first and the second kernel sizes will be used for the last two layers. | |
channels (int): Initial number of channels for conv layer. | |
max_downsample_channels (int): Maximum number of channels for downsampling layers. | |
bias (bool): Whether to add bias parameter in convolution layers. | |
downsample_scales (list): List of downsampling scales. | |
nonlinear_activation (str): Activation function module name. | |
nonlinear_activation_params (dict): Hyperparameters for activation function. | |
pad (str): Padding function module name before dilated convolution layer. | |
pad_params (dict): Hyperparameters for padding function. | |
use_causal_conv (bool): Whether to use causal convolution. | |
""" | |
        super(MelGANMultiScaleDiscriminator, self).__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for _ in range(scales):
            self.discriminators += [
                MelGANDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    channels=channels,
                    max_downsample_channels=max_downsample_channels,
                    bias=bias,
                    downsample_scales=downsample_scales,
                    nonlinear_activation=nonlinear_activation,
                    nonlinear_activation_params=nonlinear_activation_params,
                    pad=pad,
                    pad_params=pad_params,
                )
            ]
        self.pooling = getattr(torch.nn, downsample_pooling)(
            **downsample_pooling_params
        )
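        # NOTE (illustrative): with the default AvgPool1d settings
        # (kernel_size=4, stride=2, padding=1), each successive discriminator
        # sees the waveform downsampled by a further factor of 2.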
        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: List of lists of each discriminator's outputs, each of which consists of the
                output tensors of every layer.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
            x = self.pooling(x)

        return outs
    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)
    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the official implementation:
        https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py

        """

        def _reset_parameters(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
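

# Usage sketch (not part of the module; shapes assume the default arguments):
#
#     >>> d = MelGANMultiScaleDiscriminator()
#     >>> wav = torch.randn(2, 1, 16384)
#     >>> outs = d(wav)
#     >>> len(outs), len(outs[0])  # 3 scales, 7 layer outputs per scale
#     (3, 7)
#     >>> d.remove_weight_norm()  # typically called once before inference/export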