Plachta committed
Commit 8b2e962 · verified · 1 Parent(s): 50415c8

Update modules/hifigan/generator.py

Files changed (1):
  modules/hifigan/generator.py  +454 -453
modules/hifigan/generator.py CHANGED
@@ -400,3 +400,4 @@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        f0 = self.f0_predictor(x)
+    def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor:
+        if f0 is None:
+            f0 = self.f0_predictor(x)
         s = self._f02source(f0)
@@ -451,3 +452,3 @@
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor:
+        return self.forward(x=mel, f0=f0)
 
modules/hifigan/generator.py (updated file):

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HIFI-GAN"""

import typing as tp
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform

from torch import sin
from torch.nn.parameter import Parameter


"""hifigan based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: [B, 1, sample_len]
        """

        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor:
        if f0 is None:
            f0 = self.f0_predictor(x)
        s = self._f02source(f0)

        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        self.source_module.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    @torch.inference_mode()
    def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor:
        return self.forward(x=mel, f0=f0)
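
For context, a minimal usage sketch of the updated interface: the new optional f0 argument lets a caller supply an external frame-level F0 track, in which case f0_predictor is never invoked. The shapes, the 220 Hz value, and the variable names below are illustrative assumptions, not taken from this commit.

# Hypothetical usage sketch: drive HiFTGenerator with an externally supplied F0 track.
import torch
from modules.hifigan.generator import HiFTGenerator

gen = HiFTGenerator()  # defaults: 80-dim mel input, 22.05 kHz output, no f0_predictor attached
gen.eval()

mel = torch.randn(1, 80, 100)      # [B, n_mels, T_frames], dummy mel-spectrogram
f0 = torch.full((1, 100), 220.0)   # [B, T_frames], frame-level F0 in Hz (illustrative)

# Because f0 is passed explicitly, self.f0_predictor is bypassed entirely.
wav = gen.inference(mel=mel, f0=f0)
print(wav.shape)                   # e.g. torch.Size([1, 25600]) with the default config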