import torch
from torch import nn
from torch.nn.utils import weight_norm

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack


class MelganGenerator(nn.Module):
    """MelGAN-style generator: 1-D convolutions that upsample a feature sequence.

    The whole network is one ``nn.Sequential`` (``self.layers``) built as:

    1. reflection pad + ``Conv1d`` projecting ``in_channels`` -> ``base_channels``;
    2. for each factor ``u`` in ``upsample_factors``: ``LeakyReLU`` ->
       ``ConvTranspose1d`` (stride ``u``) halving the channel count ->
       ``ResidualStack``;
    3. ``LeakyReLU`` -> reflection pad + ``Conv1d`` projecting to
       ``out_channels`` -> ``Tanh`` (so outputs lie in (-1, 1)).

    The input/output projections and the transposed convolutions are wrapped
    in ``weight_norm``.

    Args:
        in_channels: channels of the conditioning input (default 80; presumably
            mel bands -- confirm against the feature config).
        out_channels: channels of the generated output.
        proj_kernel: kernel size of the input/output projections; must be odd so
            the symmetric reflection padding preserves the sequence length.
        base_channels: channel width after the input projection; halved by
            every upsampling stage.
        upsample_factors: per-stage upsampling factors along the last axis. May
            be empty, in which case the model is just the two projections.
        res_kernel: kernel size forwarded to each ``ResidualStack``.
        num_res_blocks: number of blocks forwarded to each ``ResidualStack``.
    """

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()

        # assert model parameters
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."

        # setup additional model parameters
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        # Frames of replicate padding added to each side of the input by `inference()`.
        self.inference_padding = 2

        # NOTE: the order of modules below determines the `layers.<idx>.*`
        # state_dict keys, so it must stay stable for checkpoint compatibility.
        # `inference()` also relies on `self.layers[1]` being the input projection.

        # initial layer
        layers = [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # Channel width feeding the final projection. Initialized here so an empty
        # `upsample_factors` still yields a valid model instead of an
        # UnboundLocalError when the final layer reads it.
        layer_out_channels = base_channels

        # upsampling layers and residual stacks
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            # With kernel 2u, stride u, padding u//2 + u%2 and output_padding u%2,
            # ConvTranspose1d maps an input of length L to exactly L * u.
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # final layer
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        """Run the full generator on ``c``.

        Assumes ``c`` is (batch, in_channels, frames) -- TODO confirm with callers.
        """
        return self.layers(c)

    def inference(self, c):
        """Generate output for ``c`` on the model's device, with edge padding.

        ``c`` is moved to the device of the input projection's weight
        (``self.layers[1]``) and replicate-padded by ``self.inference_padding``
        frames on both ends of its last dimension before the forward pass. The
        result is therefore longer than ``forward(c)``; any trimming is
        presumably the caller's responsibility.
        """
        c = c.to(self.layers[1].weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        """Strip weight normalization from every layer that holds state.

        Stateless layers (padding, activations, Tanh) have an empty state dict
        and are skipped. Layers wrapped directly with ``weight_norm`` are
        handled by ``nn.utils.remove_weight_norm``. If that raises
        ``ValueError``, the layer is not itself weight-normed (here: the
        ``ResidualStack`` containers), so its own ``remove_weight_norm()`` is
        called instead.
        """
        for layer in self.layers:
            if layer.state_dict():
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load weights from ``checkpoint_path`` (read with ``load_fsspec`` onto CPU).

        Args:
            config: unused here (presumably kept so vocoder models share one
                signature).
            checkpoint_path: location passed to ``load_fsspec``. The loaded
                object must hold the model state dict under the ``"model"`` key.
            eval: if True, switch to eval mode and remove weight normalization
                (not reversible on this instance).
            cache: forwarded to ``load_fsspec``.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()