File size: 5,527 Bytes
8c70653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import math

import numpy as np
import torch

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample


class ParallelWaveganGenerator(torch.nn.Module):
    """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.

    It is similar to WaveNet with no causal convolution.
    It is conditioned on an aux feature (spectrogram) to generate
    an output waveform from an input noise.

    Args:
        in_channels: input channels of the first 1x1 convolution. Note that
            ``forward`` always draws single-channel noise.
        out_channels: channels produced by the last 1x1 convolution.
        kernel_size: kernel size handed to every ``ResidualBlock``.
        num_res_blocks: total number of residual blocks; must be divisible
            by ``stacks``.
        stacks: number of dilation cycles. The dilation restarts at 1 every
            ``num_res_blocks // stacks`` blocks.
        res_channels: output channels of the first convolution, also handed
            to every ``ResidualBlock``.
        gate_channels: handed to every ``ResidualBlock``.
        skip_channels: handed to every ``ResidualBlock``; also the input
            width of the output convolutions.
        aux_channels: handed to every ``ResidualBlock``.
        dropout: handed to every ``ResidualBlock``.
        bias: handed to every ``ResidualBlock``.
        use_weight_norm: apply weight normalization to all Conv1d/Conv2d
            sub-modules at construction time.
        upsample_factors: handed to ``ConvUpsample``; their product is the
            ratio between the noise length and the conditioning length.
        inference_padding: number of steps replicated at each end of the
            last axis of the conditioning tensor in ``inference``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=[4, 4, 4, 4],
        inference_padding=2,
    ):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.num_res_blocks = num_res_blocks
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.upsample_factors = upsample_factors
        self.upsample_scale = np.prod(upsample_factors)
        self.inference_padding = inference_padding
        self.use_weight_norm = use_weight_norm

        # check the number of layers and stacks
        assert num_res_blocks % stacks == 0
        layers_per_stack = num_res_blocks // stacks

        # define first convolution
        self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)

        # define conv + upsampling network
        self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)

        # define residual blocks; the dilation doubles inside a stack and
        # restarts from 1 at the beginning of the next stack
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(num_res_blocks):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
            ]
        )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Generate an output from random noise conditioned on ``c``.

        c: (B, C ,T').
        o: Output tensor (B, out_channels, T)

        The noise has ``T' * upsample_scale`` steps and a single channel.
        """
        # random noise, drawn on CPU and then moved next to the model weights
        x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
        x = x.to(self.first_conv.bias.device)

        # perform upsampling (``c`` was already dereferenced above, so only
        # the presence of the upsampling network needs checking)
        if self.upsample_net is not None:
            c = self.upsample_net(c)
            assert (
                c.shape[-1] == x.shape[-1]
            ), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            # each block returns the next hidden tensor and a second tensor
            # that is accumulated across all blocks
            x, h = f(x, c)
            skips += h
        # rescale the accumulated sum by 1 / sqrt(number of blocks)
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    @torch.no_grad()
    def inference(self, c):
        """Run ``forward`` without gradient tracking.

        ``c`` is moved to the model's device and its last axis is padded on
        both ends with ``inference_padding`` replicated steps beforehand.
        """
        c = c.to(self.first_conv.weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.forward(c)

    def remove_weight_norm(self):
        """Remove weight normalization from every sub-module that has it."""

        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d/Conv2d sub-module."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                # print(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        """Return ``(kernel_size - 1) * sum(dilations) + 1``.

        ``dilations`` holds ``dilation(i % (layers // stacks))`` for each of
        the ``layers`` layers. ``layers`` must be divisible by ``stacks``.
        """
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Receptive field size computed from this model's block layout."""
        # The number of layers is stored as ``num_res_blocks``; there is no
        # ``layers`` attribute on this class.
        return self._get_receptive_field_size(self.num_res_blocks, self.stacks, self.kernel_size)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model weights from ``checkpoint_path``.

        The checkpoint is read with ``load_fsspec`` onto the CPU and its
        ``"model"`` entry is loaded as this module's state dict. ``config``
        is accepted but not used. When ``eval`` is true the module is put in
        evaluation mode and, if it was built with weight normalization, the
        normalization is removed.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            if self.use_weight_norm:
                self.remove_weight_norm()