import math

import numpy as np
import torch

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample


class ParallelWaveganGenerator(torch.nn.Module):
    """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.

    It is similar to WaveNet with no causal convolution. It is conditioned
    on an aux feature (spectrogram) to generate an output waveform from an
    input noise.

    Args:
        in_channels: Channels of the input noise fed to the first 1x1 convolution.
        out_channels: Channels of the generated output.
        kernel_size: Kernel size passed to every residual block.
        num_res_blocks: Total number of residual blocks; must be divisible by ``stacks``.
        stacks: Number of dilation cycles. The dilation restarts at 1 at the
            beginning of each stack and doubles with every block inside it.
        res_channels: Residual channels of the residual blocks.
        gate_channels: Gate channels of the residual blocks.
        skip_channels: Skip-connection channels of the residual blocks.
        aux_channels: Channels of the conditioning feature.
        dropout: Dropout value passed to the residual blocks.
        bias: Bias flag passed to the residual blocks.
        use_weight_norm: If True, weight normalization is applied to every
            ``Conv1d`` / ``Conv2d`` sub-module at construction time.
        upsample_factors: Per-layer upsampling factors; their product is the
            ratio between output length and conditioning length.
        inference_padding: Number of frames replicated on both ends of the
            conditioning feature in :meth:`inference`.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=[4, 4, 4, 4],
        inference_padding=2,
    ):
        super().__init__()
        # Copy so that the shared default list (or a caller's list) is never
        # aliased by instance state.
        upsample_factors = list(upsample_factors)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.num_res_blocks = num_res_blocks
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.upsample_factors = upsample_factors
        self.upsample_scale = np.prod(upsample_factors)
        self.inference_padding = inference_padding
        self.use_weight_norm = use_weight_norm

        # check the number of layers and stacks
        assert (
            num_res_blocks % stacks == 0
        ), f" [!] num_res_blocks ({num_res_blocks}) must be divisible by stacks ({stacks})."
        layers_per_stack = num_res_blocks // stacks

        # define first convolution
        self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)

        # define conv + upsampling network
        self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(num_res_blocks):
            # dilation cycles 1, 2, 4, ... and restarts at every stack
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
            ]
        )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Generate an output from random noise conditioned on ``c``.

        Args:
            c: Conditioning feature of shape (B, C, T').

        Returns:
            Output tensor of shape (B, out_channels, T), where
            T = T' * prod(upsample_factors).
        """
        # random noise with as many channels as the first convolution expects
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2] * self.upsample_scale])
        x = x.to(self.first_conv.bias.device)

        # perform upsampling
        if c is not None and self.upsample_net is not None:
            c = self.upsample_net(c)
            assert (
                c.shape[-1] == x.shape[-1]
            ), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        # scale the summed skip connections by 1 / sqrt(number of blocks)
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    @torch.no_grad()
    def inference(self, c):
        """Run :meth:`forward` without gradients on a padded conditioning feature.

        ``c`` is moved to the model device and ``inference_padding`` frames
        are replicated on both ends of its last dimension before generation.
        """
        c = c.to(self.first_conv.weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.forward(c)

    def remove_weight_norm(self):
        """Remove weight normalization from every sub-module that has it."""

        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every ``Conv1d`` / ``Conv2d`` sub-module."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                # print(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        """Compute the receptive field of a stack of dilated convolutions.

        Args:
            layers: Total number of layers; must be divisible by ``stacks``.
            stacks: Number of dilation cycles.
            kernel_size: Kernel size shared by all layers.
            dilation: Callable mapping a layer's index within its cycle to
                its dilation factor.

        Returns:
            ``(kernel_size - 1) * sum(dilations) + 1``.
        """
        assert layers % stacks == 0, f" [!] layers ({layers}) must be divisible by stacks ({stacks})."
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Receptive field size computed from this model's residual block layout."""
        # The layer count is stored as ``num_res_blocks``; there is no
        # ``layers`` attribute on this class.
        return self._get_receptive_field_size(self.num_res_blocks, self.stacks, self.kernel_size)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model weights from a checkpoint file.

        Args:
            config: Unused; kept for interface compatibility.
            checkpoint_path: Path handed to ``load_fsspec``.
            eval: If True, switch the model to evaluation mode and, when the
                model was built with weight norm, remove it.
            cache: Forwarded to ``load_fsspec``.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            if self.use_weight_norm:
                self.remove_weight_norm()