import torch
from torch import nn
from torch.nn.utils import weight_norm

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack


class MelganGenerator(nn.Module):
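    """MelGAN waveform generator.

    Upsamples an input mel spectrogram to a raw waveform through a stack of
    transposed convolutions interleaved with residual stacks.

    Args:
        in_channels (int): number of input (mel) channels.
        out_channels (int): number of output (waveform) channels.
        proj_kernel (int): kernel size of the initial and final projection convolutions.
        base_channels (int): channels after the initial projection; halved at every upsampling stage.
        upsample_factors (tuple): upsampling factor of each stage; their product is the hop length in samples.
        res_kernel (int): kernel size used inside each ResidualStack.
        num_res_blocks (int): number of residual blocks per ResidualStack.
    """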
    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()

        # assert model parameters
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."

        # setup additional model parameters
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        self.inference_padding = 2

        # initial layer
        layers = []
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # upsampling layers and residual stacks
        for idx, upsample_factor in enumerate(upsample_factors):
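            # Halve the channel count at each stage; kernel, stride, padding and
            # output_padding below are chosen so that ConvTranspose1d lengthens the
            # sequence by exactly `upsample_factor`.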
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # final layer
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        return self.layers(c)

    def inference(self, c):
        c = c.to(self.layers[1].weight.device)
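        # replicate-pad both ends of the spectrogram to reduce artifacts at the
        # waveform boundaries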
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        for layer in self.layers:
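            # parameter-free layers (activations, padding) are skipped; ResidualStack
            # carries no weight-norm hook itself, so removal falls back to its own
            # remove_weight_norm()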
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()
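

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the shapes are assumptions based on
# the defaults above: 80 mel bands and upsample_factors=(8, 8, 2, 2), i.e. a
# hop length of 8 * 8 * 2 * 2 = 256 output samples per input frame).
if __name__ == "__main__":
    model = MelganGenerator()
    mel = torch.randn(1, 80, 50)  # (batch, mel_channels, frames)
    with torch.no_grad():
        wav = model.inference(mel)
    # inference() pads the input by `inference_padding` frames on each side,
    # so the output is (50 + 2 * 2) * 256 = 13824 samples long
    print(wav.shape)  # torch.Size([1, 1, 13824])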