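"""WaveNet-style vocoder backbone (FastDiff / DiffWave flavour).

A stack of dilated-convolution residual blocks conditioned on an upsampled
mel spectrogram (local conditioner) and on a noise-scale embedding, applied
to a noisy (sub-band) waveform.
"""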
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.FastDiff.module.util import calc_noise_scale_embedding


def swish(x):
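    """Swish activation: x * sigmoid(x) (equivalent to torch.nn.functional.silu)."""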
    return x * torch.sigmoid(x)


# dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super(Conv, self).__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out


# conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ZeroConv1d, self).__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()
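        # zero-initialised weights and bias, so this projection (used as the
        # network's final layer) outputs exactly 0 at initialisation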

    def forward(self, x):
        out = self.conv(x)
        return out


# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class Residual_block(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation, 
                 noise_scale_embed_dim_out, multiband=True):
        super(Residual_block, self).__init__()
        self.res_channels = res_channels

        # the layer-specific fc for noise scale embedding
        self.fc_t = nn.Linear(noise_scale_embed_dim_out, self.res_channels)

        # dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

        # add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = torch.nn.ModuleList()
        if multiband:
            params = 8   # two x8 stages -> x64 total time upsampling per sub-band
        else:
            params = 16  # two x16 stages -> x256 total time upsampling
        # the total upsampling factor (params ** 2) must be at least the
        # frames-to-samples ratio of the mel conditioner (see the assert in forward)
        for s in [params, params]:
            conv_trans2d = torch.nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            conv_trans2d = torch.nn.utils.weight_norm(conv_trans2d)
            torch.nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)
        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)  # 80 is mel bands

        # residual conv1x1 layer, connect to next residual layer
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)

        # skip conv1x1 layer, add to all skip outputs through skip connections
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.skip_conv.weight)

    def forward(self, input_data):
        x, mel_spec, noise_scale_embed = input_data
        h = x
        B, C, L = x.shape   # B, res_channels, L
        assert C == self.res_channels

        # add in noise scale embedding
        part_t = self.fc_t(noise_scale_embed)
        part_t = part_t.view([B, self.res_channels, 1])
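        # note: `h` aliases `x`, so this in-place add also feeds the embedding
        # into the residual-branch input used at the end of forward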
        h += part_t

        # dilated conv layer
        h = self.dilated_conv_layer(h)

        # add mel spectrogram as (local) conditioner
        assert mel_spec is not None

        # Upsample spectrogram to size of audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)  # (B, 1, 80, T')
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= L
        if mel_spec.size(2) > L:
            mel_spec = mel_spec[:, :, :L]

        mel_spec = self.mel_conv(mel_spec)
        h += mel_spec

        # gated-tanh nonlinearity
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

        # residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        return (x + res) * math.sqrt(0.5), skip  # normalize for training stability


class Residual_group(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle, 
                 noise_scale_embed_dim_in, 
                 noise_scale_embed_dim_mid,
                 noise_scale_embed_dim_out, multiband):
        super(Residual_group, self).__init__()
        self.num_res_layers = num_res_layers
        self.noise_scale_embed_dim_in = noise_scale_embed_dim_in

        # the two fc layers for the noise scale embedding, shared across all residual blocks
        self.fc_t1 = nn.Linear(noise_scale_embed_dim_in, noise_scale_embed_dim_mid)
        self.fc_t2 = nn.Linear(noise_scale_embed_dim_mid, noise_scale_embed_dim_out)

        # stack all residual blocks with dilations 1, 2, ..., 2 ** (dilation_cycle - 1), repeated cyclically
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(Residual_block(res_channels, skip_channels, 
                                                       dilation=2 ** (n % dilation_cycle),
                                                       noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband))

    def forward(self, input_data):
        x, mel_spectrogram, noise_scales = input_data

        # embed noise scale
        noise_scale_embed = calc_noise_scale_embedding(noise_scales, self.noise_scale_embed_dim_in)
        noise_scale_embed = swish(self.fc_t1(noise_scale_embed))
        noise_scale_embed = swish(self.fc_t2(noise_scale_embed))

        # pass all residual layers
        h = x
        skip = 0
        for n in range(self.num_res_layers):
            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, noise_scale_embed))  # use the output from last residual layer
            skip += skip_n  # accumulate all skip outputs

        return skip * math.sqrt(1.0 / self.num_res_layers)  # normalize for training stability


class WaveNet_vocoder(nn.Module):
    def __init__(self, in_channels, res_channels, skip_channels, out_channels, 
                 num_res_layers, dilation_cycle, 
                 noise_scale_embed_dim_in, 
                 noise_scale_embed_dim_mid,
                 noise_scale_embed_dim_out, multiband):
        super(WaveNet_vocoder, self).__init__()

        # initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU())
        
        # all residual layers
        self.residual_layer = Residual_group(res_channels=res_channels, 
                                             skip_channels=skip_channels, 
                                             num_res_layers=num_res_layers, 
                                             dilation_cycle=dilation_cycle,
                                             noise_scale_embed_dim_in=noise_scale_embed_dim_in,
                                             noise_scale_embed_dim_mid=noise_scale_embed_dim_mid,
                                             noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband)
        
        # final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
                                        nn.ReLU(),
                                        ZeroConv1d(skip_channels, out_channels))

    def forward(self, input_data):
        audio, mel_spectrogram, noise_scales = input_data  # b x band x T, b x 80 x T', b x 1
        x = audio
        x = self.init_conv(x)
        x = self.residual_layer((x, mel_spectrogram, noise_scales))
        x = self.final_conv(x)

        return x
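

# ---------------------------------------------------------------------------
# Minimal shape smoke test (illustrative only).
# The hyperparameters below are placeholders, not values from any FastDiff
# config, and `calc_noise_scale_embedding` is assumed to map a (B, 1) tensor
# of noise scales to a (B, noise_scale_embed_dim_in) embedding on the same
# device, which is how Residual_group consumes it above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    B, bands, hop = 2, 4, 64      # assumed: 4 sub-bands, per-band hop of 64 (= 8 * 8 upsampling)
    T_mel = 20                    # number of mel frames
    T_audio = T_mel * hop         # per-band audio length aligned with the mel frames

    model = WaveNet_vocoder(in_channels=bands,
                            res_channels=32,
                            skip_channels=32,
                            out_channels=bands,
                            num_res_layers=4,
                            dilation_cycle=4,
                            noise_scale_embed_dim_in=128,
                            noise_scale_embed_dim_mid=256,
                            noise_scale_embed_dim_out=256,
                            multiband=True)

    audio = torch.randn(B, bands, T_audio)   # noisy sub-band waveform
    mel = torch.randn(B, 80, T_mel)          # mel-spectrogram conditioner
    noise_scales = torch.rand(B, 1)          # one noise scale per example

    out = model((audio, mel, noise_scales))
    print(out.shape)                         # expected: torch.Size([2, 4, 1280])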