DJQmUKV commited on
Commit
9ef7359
·
1 Parent(s): 0598547

chore: sync with latest code, support RVC v2 model instead

Browse files
app_multi.py CHANGED
@@ -1,7 +1,10 @@
1
  from typing import Union
2
 
 
 
3
  import asyncio
4
  import json
 
5
  from os import path, getenv
6
 
7
  import gradio as gr
@@ -13,11 +16,11 @@ import librosa
13
 
14
  import edge_tts
15
 
16
- from config import device
17
  import util
18
  from infer_pack.models import (
19
- SynthesizerTrnMs256NSFsid,
20
- SynthesizerTrnMs256NSFsid_nono
21
  )
22
  from vc_infer_pipeline import VC
23
 
@@ -25,6 +28,45 @@ from vc_infer_pipeline import VC
25
  # Reference: https://huggingface.co/spaces/zomehwh/rvc-models/blob/main/app.py#L21 # noqa
26
  in_hf_space = getenv('SYSTEM') == 'spaces'
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  app_css = '''
30
  #model_info img {
@@ -45,11 +87,11 @@ app = gr.Blocks(
45
  )
46
 
47
  # Load hubert model
48
- hubert_model = util.load_hubert_model(device, 'hubert_base.pt')
49
  hubert_model.eval()
50
 
51
  # Load models
52
- multi_cfg = json.load(open('multi_config.json', 'r'))
53
  loaded_models = []
54
 
55
  for model_name in multi_cfg.get('models'):
@@ -69,24 +111,24 @@ for model_name in multi_cfg.get('models'):
69
  cpt['config'][-3] = cpt['weight']['emb_g.weight'].shape[0] # n_spk
70
 
71
  if_f0 = cpt.get('f0', 1)
72
- net_g: Union[SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono]
73
  if if_f0 == 1:
74
- net_g = SynthesizerTrnMs256NSFsid(
75
  *cpt['config'],
76
- is_half=util.is_half(device)
77
  )
78
  else:
79
- net_g = SynthesizerTrnMs256NSFsid_nono(*cpt['config'])
80
 
81
  del net_g.enc_q
82
 
83
  # According to original code, this thing seems necessary.
84
  print(net_g.load_state_dict(cpt['weight'], strict=False))
85
 
86
- net_g.eval().to(device)
87
- net_g = net_g.half() if util.is_half(device) else net_g.float()
88
 
89
- vc = VC(tgt_sr, device, util.is_half(device))
90
 
91
  loaded_models.append(dict(
92
  name=model_name,
@@ -104,7 +146,10 @@ tts_speakers_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_vo
104
 
105
 
106
  # https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/main/infer-web.py#L118 # noqa
107
- def vc_func(input_audio, model_index, pitch_adjust, f0_method, feat_ratio):
 
 
 
108
  if input_audio is None:
109
  return (None, 'Please provide input audio.')
110
 
@@ -138,27 +183,47 @@ def vc_func(input_audio, model_index, pitch_adjust, f0_method, feat_ratio):
138
 
139
  pitch_int = int(pitch_adjust)
140
 
 
 
 
 
 
141
  times = [0, 0, 0]
 
 
 
 
142
  output_audio = model['vc'].pipeline(
143
  hubert_model,
144
  model['net_g'],
145
  model['metadata'].get('speaker_id', 0),
146
  audio_npy,
 
147
  times,
148
  pitch_int,
149
  f0_method,
150
  path.join('model', model['name'], model['metadata']['feat_index']),
151
- path.join('model', model['name'], model['metadata']['feat_npy']),
152
  feat_ratio,
153
- model['if_f0']
 
 
 
 
 
 
 
 
 
 
154
  )
155
 
156
  print(f'npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s')
157
- return ((model['target_sr'], output_audio), 'Success')
158
 
159
 
160
  async def edge_tts_vc_func(
161
- input_text, model_index, tts_speaker, pitch_adjust, f0_method, feat_ratio
 
162
  ):
163
  if input_text is None:
164
  return (None, 'Please provide TTS text.')
@@ -176,7 +241,10 @@ async def edge_tts_vc_func(
176
  model_index,
177
  pitch_adjust,
178
  f0_method,
179
- feat_ratio
 
 
 
180
  )
181
 
182
 
@@ -210,9 +278,13 @@ def update_model_info(model_index):
210
  )
211
 
212
 
213
- def _example_vc(input_audio, model_index, pitch_adjust, f0_method, feat_ratio):
 
 
 
214
  (audio, message) = vc_func(
215
- input_audio, model_index, pitch_adjust, f0_method, feat_ratio
 
216
  )
217
  return (
218
  audio,
@@ -222,11 +294,12 @@ def _example_vc(input_audio, model_index, pitch_adjust, f0_method, feat_ratio):
222
 
223
 
224
  async def _example_edge_tts(
225
- input_text, model_index, tts_speaker, pitch_adjust, f0_method, feat_ratio
 
226
  ):
227
  (audio, message) = await edge_tts_vc_func(
228
  input_text, model_index, tts_speaker, pitch_adjust, f0_method,
229
- feat_ratio
230
  )
231
  return (
232
  audio,
@@ -280,13 +353,40 @@ with app:
280
  value='pm',
281
  interactive=True
282
  )
283
- feat_ratio = gr.Slider(
284
- label='Feature ratio',
285
- minimum=0,
286
- maximum=1,
287
- step=0.1,
288
- value=0.6
289
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  with gr.Column():
292
  # Model select
@@ -314,7 +414,10 @@ with app:
314
  output_msg = gr.Textbox(label='Output message')
315
 
316
  multi_examples = multi_cfg.get('examples')
317
- if multi_examples:
 
 
 
318
  with gr.Accordion('Sweet sweet examples', open=False):
319
  with gr.Row():
320
  # VC Example
@@ -328,8 +431,8 @@ with app:
328
  ],
329
  outputs=[output_audio, output_msg, model_info],
330
  fn=_example_vc,
331
- cache_examples=False,
332
- run_on_click=False
333
  )
334
 
335
  # Edge TTS Example
@@ -343,13 +446,16 @@ with app:
343
  ],
344
  outputs=[output_audio, output_msg, model_info],
345
  fn=_example_edge_tts,
346
- cache_examples=False,
347
- run_on_click=False
348
  )
349
 
350
  vc_convert_btn.click(
351
  vc_func,
352
- [input_audio, model_index, pitch_adjust, f0_method, feat_ratio],
 
 
 
353
  [output_audio, output_msg],
354
  api_name='audio_conversion'
355
  )
@@ -358,7 +464,7 @@ with app:
358
  edge_tts_vc_func,
359
  [
360
  tts_input, model_index, tts_speaker, pitch_adjust, f0_method,
361
- feat_ratio
362
  ],
363
  [output_audio, output_msg],
364
  api_name='tts_conversion'
@@ -375,5 +481,9 @@ with app:
375
  app.queue(
376
  concurrency_count=1,
377
  max_size=20,
378
- api_open=False
379
- ).launch()
 
 
 
 
 
1
  from typing import Union
2
 
3
+ from argparse import ArgumentParser
4
+
5
  import asyncio
6
  import json
7
+ import hashlib
8
  from os import path, getenv
9
 
10
  import gradio as gr
 
16
 
17
  import edge_tts
18
 
19
+ import config
20
  import util
21
  from infer_pack.models import (
22
+ SynthesizerTrnMs768NSFsid,
23
+ SynthesizerTrnMs768NSFsid_nono
24
  )
25
  from vc_infer_pipeline import VC
26
 
 
28
  # Reference: https://huggingface.co/spaces/zomehwh/rvc-models/blob/main/app.py#L21 # noqa
29
  in_hf_space = getenv('SYSTEM') == 'spaces'
30
 
31
+ # Argument parsing
32
+ arg_parser = ArgumentParser()
33
+ arg_parser.add_argument(
34
+ '--hubert',
35
+ default=getenv('RVC_HUBERT', 'hubert_base.pt'),
36
+ help='path to hubert base model (default: hubert_base.pt)'
37
+ )
38
+ arg_parser.add_argument(
39
+ '--config',
40
+ default=getenv('RVC_MULTI_CFG', 'multi_config.json'),
41
+ help='path to config file (default: multi_config.json)'
42
+ )
43
+ arg_parser.add_argument(
44
+ '--bind',
45
+ default=getenv('RVC_LISTEN_ADDR', '127.0.0.1'),
46
+ help='gradio server listen address (default: 127.0.0.1)'
47
+ )
48
+ arg_parser.add_argument(
49
+ '--port',
50
+ default=getenv('RVC_LISTEN_PORT', '7860'),
51
+ type=int,
52
+ help='gradio server listen port (default: 7860)'
53
+ )
54
+ arg_parser.add_argument(
55
+ '--share',
56
+ action='store_true',
57
+ help='let gradio create a public link for you'
58
+ )
59
+ arg_parser.add_argument(
60
+ '--api',
61
+ action='store_true',
62
+ help='enable api endpoint'
63
+ )
64
+ arg_parser.add_argument(
65
+ '--cache-examples',
66
+ action='store_true',
67
+ help='enable example caching, please remember delete gradio_cached_examples folder when example config has been modified' # noqa
68
+ )
69
+ args = arg_parser.parse_args()
70
 
71
  app_css = '''
72
  #model_info img {
 
87
  )
88
 
89
  # Load hubert model
90
+ hubert_model = util.load_hubert_model(config.device, args.hubert)
91
  hubert_model.eval()
92
 
93
  # Load models
94
+ multi_cfg = json.load(open(args.config, 'r'))
95
  loaded_models = []
96
 
97
  for model_name in multi_cfg.get('models'):
 
111
  cpt['config'][-3] = cpt['weight']['emb_g.weight'].shape[0] # n_spk
112
 
113
  if_f0 = cpt.get('f0', 1)
114
+ net_g: Union[SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono]
115
  if if_f0 == 1:
116
+ net_g = SynthesizerTrnMs768NSFsid(
117
  *cpt['config'],
118
+ is_half=util.is_half(config.device)
119
  )
120
  else:
121
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt['config'])
122
 
123
  del net_g.enc_q
124
 
125
  # According to original code, this thing seems necessary.
126
  print(net_g.load_state_dict(cpt['weight'], strict=False))
127
 
128
+ net_g.eval().to(config.device)
129
+ net_g = net_g.half() if util.is_half(config.device) else net_g.float()
130
 
131
+ vc = VC(tgt_sr, config)
132
 
133
  loaded_models.append(dict(
134
  name=model_name,
 
146
 
147
 
148
  # https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/main/infer-web.py#L118 # noqa
149
+ def vc_func(
150
+ input_audio, model_index, pitch_adjust, f0_method, feat_ratio,
151
+ filter_radius, rms_mix_rate, resample_option
152
+ ):
153
  if input_audio is None:
154
  return (None, 'Please provide input audio.')
155
 
 
183
 
184
  pitch_int = int(pitch_adjust)
185
 
186
+ resample = (
187
+ 0 if resample_option == 'Disable resampling'
188
+ else int(resample_option)
189
+ )
190
+
191
  times = [0, 0, 0]
192
+
193
+ checksum = hashlib.sha512()
194
+ checksum.update(audio_npy.tobytes())
195
+
196
  output_audio = model['vc'].pipeline(
197
  hubert_model,
198
  model['net_g'],
199
  model['metadata'].get('speaker_id', 0),
200
  audio_npy,
201
+ checksum.hexdigest(),
202
  times,
203
  pitch_int,
204
  f0_method,
205
  path.join('model', model['name'], model['metadata']['feat_index']),
 
206
  feat_ratio,
207
+ model['if_f0'],
208
+ filter_radius,
209
+ model['target_sr'],
210
+ resample,
211
+ rms_mix_rate,
212
+ 'v2'
213
+ )
214
+
215
+ out_sr = (
216
+ resample if resample >= 16000 and model['target_sr'] != resample
217
+ else model['target_sr']
218
  )
219
 
220
  print(f'npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s')
221
+ return ((out_sr, output_audio), 'Success')
222
 
223
 
224
  async def edge_tts_vc_func(
225
+ input_text, model_index, tts_speaker, pitch_adjust, f0_method, feat_ratio,
226
+ filter_radius, rms_mix_rate, resample_option
227
  ):
228
  if input_text is None:
229
  return (None, 'Please provide TTS text.')
 
241
  model_index,
242
  pitch_adjust,
243
  f0_method,
244
+ feat_ratio,
245
+ filter_radius,
246
+ rms_mix_rate,
247
+ resample_option
248
  )
249
 
250
 
 
278
  )
279
 
280
 
281
+ def _example_vc(
282
+ input_audio, model_index, pitch_adjust, f0_method, feat_ratio,
283
+ filter_radius, rms_mix_rate, resample_option
284
+ ):
285
  (audio, message) = vc_func(
286
+ input_audio, model_index, pitch_adjust, f0_method, feat_ratio,
287
+ filter_radius, rms_mix_rate, resample_option
288
  )
289
  return (
290
  audio,
 
294
 
295
 
296
  async def _example_edge_tts(
297
+ input_text, model_index, tts_speaker, pitch_adjust, f0_method, feat_ratio,
298
+ filter_radius, rms_mix_rate, resample_option
299
  ):
300
  (audio, message) = await edge_tts_vc_func(
301
  input_text, model_index, tts_speaker, pitch_adjust, f0_method,
302
+ feat_ratio, filter_radius, rms_mix_rate, resample_option
303
  )
304
  return (
305
  audio,
 
353
  value='pm',
354
  interactive=True
355
  )
356
+
357
+ with gr.Accordion('Advanced options', open=False):
358
+ feat_ratio = gr.Slider(
359
+ label='Feature ratio',
360
+ minimum=0,
361
+ maximum=1,
362
+ step=0.1,
363
+ value=0.6
364
+ )
365
+ filter_radius = gr.Slider(
366
+ label='Filter radius',
367
+ minimum=0,
368
+ maximum=7,
369
+ step=1,
370
+ value=3
371
+ )
372
+ rms_mix_rate = gr.Slider(
373
+ label='Volume envelope mix rate',
374
+ minimum=0,
375
+ maximum=1,
376
+ step=0.1,
377
+ value=1
378
+ )
379
+ resample_rate = gr.Dropdown(
380
+ [
381
+ 'Disable resampling',
382
+ '16000',
383
+ '22050',
384
+ '44100',
385
+ '48000'
386
+ ],
387
+ label='Resample rate',
388
+ value='Disable resampling'
389
+ )
390
 
391
  with gr.Column():
392
  # Model select
 
414
  output_msg = gr.Textbox(label='Output message')
415
 
416
  multi_examples = multi_cfg.get('examples')
417
+ if (
418
+ multi_examples and
419
+ multi_examples.get('vc') and multi_examples.get('tts_vc')
420
+ ):
421
  with gr.Accordion('Sweet sweet examples', open=False):
422
  with gr.Row():
423
  # VC Example
 
431
  ],
432
  outputs=[output_audio, output_msg, model_info],
433
  fn=_example_vc,
434
+ cache_examples=args.cache_examples,
435
+ run_on_click=args.cache_examples
436
  )
437
 
438
  # Edge TTS Example
 
446
  ],
447
  outputs=[output_audio, output_msg, model_info],
448
  fn=_example_edge_tts,
449
+ cache_examples=args.cache_examples,
450
+ run_on_click=args.cache_examples
451
  )
452
 
453
  vc_convert_btn.click(
454
  vc_func,
455
+ [
456
+ input_audio, model_index, pitch_adjust, f0_method, feat_ratio,
457
+ filter_radius, rms_mix_rate, resample_rate
458
+ ],
459
  [output_audio, output_msg],
460
  api_name='audio_conversion'
461
  )
 
464
  edge_tts_vc_func,
465
  [
466
  tts_input, model_index, tts_speaker, pitch_adjust, f0_method,
467
+ feat_ratio, filter_radius, rms_mix_rate, resample_rate
468
  ],
469
  [output_audio, output_msg],
470
  api_name='tts_conversion'
 
481
  app.queue(
482
  concurrency_count=1,
483
  max_size=20,
484
+ api_open=args.api
485
+ ).launch(
486
+ server_name=args.bind,
487
+ server_port=args.port,
488
+ share=args.share
489
+ )
config.py CHANGED
@@ -10,8 +10,9 @@ device = (
10
  else 'cpu'
11
  )
12
  )
 
13
 
14
- x_pad = 3 if util.is_half(device) else 1
15
- x_query = 10 if util.is_half(device) else 6
16
- x_center = 60 if util.is_half(device) else 38
17
- x_max = 65 if util.is_half(device) else 41
 
10
  else 'cpu'
11
  )
12
  )
13
+ is_half = util.is_half(device)
14
 
15
+ x_pad = 3 if is_half else 1
16
+ x_query = 10 if is_half else 6
17
+ x_center = 60 if is_half else 38
18
+ x_max = 65 if is_half else 41
infer_pack/attentions.py CHANGED
@@ -1,417 +1,417 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from infer_pack import commons
9
- from infer_pack import modules
10
- from infer_pack.modules import LayerNorm
11
-
12
-
13
- class Encoder(nn.Module):
14
- def __init__(
15
- self,
16
- hidden_channels,
17
- filter_channels,
18
- n_heads,
19
- n_layers,
20
- kernel_size=1,
21
- p_dropout=0.0,
22
- window_size=10,
23
- **kwargs
24
- ):
25
- super().__init__()
26
- self.hidden_channels = hidden_channels
27
- self.filter_channels = filter_channels
28
- self.n_heads = n_heads
29
- self.n_layers = n_layers
30
- self.kernel_size = kernel_size
31
- self.p_dropout = p_dropout
32
- self.window_size = window_size
33
-
34
- self.drop = nn.Dropout(p_dropout)
35
- self.attn_layers = nn.ModuleList()
36
- self.norm_layers_1 = nn.ModuleList()
37
- self.ffn_layers = nn.ModuleList()
38
- self.norm_layers_2 = nn.ModuleList()
39
- for i in range(self.n_layers):
40
- self.attn_layers.append(
41
- MultiHeadAttention(
42
- hidden_channels,
43
- hidden_channels,
44
- n_heads,
45
- p_dropout=p_dropout,
46
- window_size=window_size,
47
- )
48
- )
49
- self.norm_layers_1.append(LayerNorm(hidden_channels))
50
- self.ffn_layers.append(
51
- FFN(
52
- hidden_channels,
53
- hidden_channels,
54
- filter_channels,
55
- kernel_size,
56
- p_dropout=p_dropout,
57
- )
58
- )
59
- self.norm_layers_2.append(LayerNorm(hidden_channels))
60
-
61
- def forward(self, x, x_mask):
62
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
- x = x * x_mask
64
- for i in range(self.n_layers):
65
- y = self.attn_layers[i](x, x, attn_mask)
66
- y = self.drop(y)
67
- x = self.norm_layers_1[i](x + y)
68
-
69
- y = self.ffn_layers[i](x, x_mask)
70
- y = self.drop(y)
71
- x = self.norm_layers_2[i](x + y)
72
- x = x * x_mask
73
- return x
74
-
75
-
76
- class Decoder(nn.Module):
77
- def __init__(
78
- self,
79
- hidden_channels,
80
- filter_channels,
81
- n_heads,
82
- n_layers,
83
- kernel_size=1,
84
- p_dropout=0.0,
85
- proximal_bias=False,
86
- proximal_init=True,
87
- **kwargs
88
- ):
89
- super().__init__()
90
- self.hidden_channels = hidden_channels
91
- self.filter_channels = filter_channels
92
- self.n_heads = n_heads
93
- self.n_layers = n_layers
94
- self.kernel_size = kernel_size
95
- self.p_dropout = p_dropout
96
- self.proximal_bias = proximal_bias
97
- self.proximal_init = proximal_init
98
-
99
- self.drop = nn.Dropout(p_dropout)
100
- self.self_attn_layers = nn.ModuleList()
101
- self.norm_layers_0 = nn.ModuleList()
102
- self.encdec_attn_layers = nn.ModuleList()
103
- self.norm_layers_1 = nn.ModuleList()
104
- self.ffn_layers = nn.ModuleList()
105
- self.norm_layers_2 = nn.ModuleList()
106
- for i in range(self.n_layers):
107
- self.self_attn_layers.append(
108
- MultiHeadAttention(
109
- hidden_channels,
110
- hidden_channels,
111
- n_heads,
112
- p_dropout=p_dropout,
113
- proximal_bias=proximal_bias,
114
- proximal_init=proximal_init,
115
- )
116
- )
117
- self.norm_layers_0.append(LayerNorm(hidden_channels))
118
- self.encdec_attn_layers.append(
119
- MultiHeadAttention(
120
- hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
- )
122
- )
123
- self.norm_layers_1.append(LayerNorm(hidden_channels))
124
- self.ffn_layers.append(
125
- FFN(
126
- hidden_channels,
127
- hidden_channels,
128
- filter_channels,
129
- kernel_size,
130
- p_dropout=p_dropout,
131
- causal=True,
132
- )
133
- )
134
- self.norm_layers_2.append(LayerNorm(hidden_channels))
135
-
136
- def forward(self, x, x_mask, h, h_mask):
137
- """
138
- x: decoder input
139
- h: encoder output
140
- """
141
- self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
- device=x.device, dtype=x.dtype
143
- )
144
- encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
- x = x * x_mask
146
- for i in range(self.n_layers):
147
- y = self.self_attn_layers[i](x, x, self_attn_mask)
148
- y = self.drop(y)
149
- x = self.norm_layers_0[i](x + y)
150
-
151
- y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
- y = self.drop(y)
153
- x = self.norm_layers_1[i](x + y)
154
-
155
- y = self.ffn_layers[i](x, x_mask)
156
- y = self.drop(y)
157
- x = self.norm_layers_2[i](x + y)
158
- x = x * x_mask
159
- return x
160
-
161
-
162
- class MultiHeadAttention(nn.Module):
163
- def __init__(
164
- self,
165
- channels,
166
- out_channels,
167
- n_heads,
168
- p_dropout=0.0,
169
- window_size=None,
170
- heads_share=True,
171
- block_length=None,
172
- proximal_bias=False,
173
- proximal_init=False,
174
- ):
175
- super().__init__()
176
- assert channels % n_heads == 0
177
-
178
- self.channels = channels
179
- self.out_channels = out_channels
180
- self.n_heads = n_heads
181
- self.p_dropout = p_dropout
182
- self.window_size = window_size
183
- self.heads_share = heads_share
184
- self.block_length = block_length
185
- self.proximal_bias = proximal_bias
186
- self.proximal_init = proximal_init
187
- self.attn = None
188
-
189
- self.k_channels = channels // n_heads
190
- self.conv_q = nn.Conv1d(channels, channels, 1)
191
- self.conv_k = nn.Conv1d(channels, channels, 1)
192
- self.conv_v = nn.Conv1d(channels, channels, 1)
193
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
- self.drop = nn.Dropout(p_dropout)
195
-
196
- if window_size is not None:
197
- n_heads_rel = 1 if heads_share else n_heads
198
- rel_stddev = self.k_channels**-0.5
199
- self.emb_rel_k = nn.Parameter(
200
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
- * rel_stddev
202
- )
203
- self.emb_rel_v = nn.Parameter(
204
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
- * rel_stddev
206
- )
207
-
208
- nn.init.xavier_uniform_(self.conv_q.weight)
209
- nn.init.xavier_uniform_(self.conv_k.weight)
210
- nn.init.xavier_uniform_(self.conv_v.weight)
211
- if proximal_init:
212
- with torch.no_grad():
213
- self.conv_k.weight.copy_(self.conv_q.weight)
214
- self.conv_k.bias.copy_(self.conv_q.bias)
215
-
216
- def forward(self, x, c, attn_mask=None):
217
- q = self.conv_q(x)
218
- k = self.conv_k(c)
219
- v = self.conv_v(c)
220
-
221
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
-
223
- x = self.conv_o(x)
224
- return x
225
-
226
- def attention(self, query, key, value, mask=None):
227
- # reshape [b, d, t] -> [b, n_h, t, d_k]
228
- b, d, t_s, t_t = (*key.size(), query.size(2))
229
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
-
233
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
- if self.window_size is not None:
235
- assert (
236
- t_s == t_t
237
- ), "Relative attention is only available for self-attention."
238
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
- rel_logits = self._matmul_with_relative_keys(
240
- query / math.sqrt(self.k_channels), key_relative_embeddings
241
- )
242
- scores_local = self._relative_position_to_absolute_position(rel_logits)
243
- scores = scores + scores_local
244
- if self.proximal_bias:
245
- assert t_s == t_t, "Proximal bias is only available for self-attention."
246
- scores = scores + self._attention_bias_proximal(t_s).to(
247
- device=scores.device, dtype=scores.dtype
248
- )
249
- if mask is not None:
250
- scores = scores.masked_fill(mask == 0, -1e4)
251
- if self.block_length is not None:
252
- assert (
253
- t_s == t_t
254
- ), "Local attention is only available for self-attention."
255
- block_mask = (
256
- torch.ones_like(scores)
257
- .triu(-self.block_length)
258
- .tril(self.block_length)
259
- )
260
- scores = scores.masked_fill(block_mask == 0, -1e4)
261
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
- p_attn = self.drop(p_attn)
263
- output = torch.matmul(p_attn, value)
264
- if self.window_size is not None:
265
- relative_weights = self._absolute_position_to_relative_position(p_attn)
266
- value_relative_embeddings = self._get_relative_embeddings(
267
- self.emb_rel_v, t_s
268
- )
269
- output = output + self._matmul_with_relative_values(
270
- relative_weights, value_relative_embeddings
271
- )
272
- output = (
273
- output.transpose(2, 3).contiguous().view(b, d, t_t)
274
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
- return output, p_attn
276
-
277
- def _matmul_with_relative_values(self, x, y):
278
- """
279
- x: [b, h, l, m]
280
- y: [h or 1, m, d]
281
- ret: [b, h, l, d]
282
- """
283
- ret = torch.matmul(x, y.unsqueeze(0))
284
- return ret
285
-
286
- def _matmul_with_relative_keys(self, x, y):
287
- """
288
- x: [b, h, l, d]
289
- y: [h or 1, m, d]
290
- ret: [b, h, l, m]
291
- """
292
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
- return ret
294
-
295
- def _get_relative_embeddings(self, relative_embeddings, length):
296
- max_relative_position = 2 * self.window_size + 1
297
- # Pad first before slice to avoid using cond ops.
298
- pad_length = max(length - (self.window_size + 1), 0)
299
- slice_start_position = max((self.window_size + 1) - length, 0)
300
- slice_end_position = slice_start_position + 2 * length - 1
301
- if pad_length > 0:
302
- padded_relative_embeddings = F.pad(
303
- relative_embeddings,
304
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
- )
306
- else:
307
- padded_relative_embeddings = relative_embeddings
308
- used_relative_embeddings = padded_relative_embeddings[
309
- :, slice_start_position:slice_end_position
310
- ]
311
- return used_relative_embeddings
312
-
313
- def _relative_position_to_absolute_position(self, x):
314
- """
315
- x: [b, h, l, 2*l-1]
316
- ret: [b, h, l, l]
317
- """
318
- batch, heads, length, _ = x.size()
319
- # Concat columns of pad to shift from relative to absolute indexing.
320
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
-
322
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
- x_flat = x.view([batch, heads, length * 2 * length])
324
- x_flat = F.pad(
325
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
- )
327
-
328
- # Reshape and slice out the padded elements.
329
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
- :, :, :length, length - 1 :
331
- ]
332
- return x_final
333
-
334
- def _absolute_position_to_relative_position(self, x):
335
- """
336
- x: [b, h, l, l]
337
- ret: [b, h, l, 2*l-1]
338
- """
339
- batch, heads, length, _ = x.size()
340
- # padd along column
341
- x = F.pad(
342
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
- )
344
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
- # add 0's in the beginning that will skew the elements after reshape
346
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
- return x_final
349
-
350
- def _attention_bias_proximal(self, length):
351
- """Bias for self-attention to encourage attention to close positions.
352
- Args:
353
- length: an integer scalar.
354
- Returns:
355
- a Tensor with shape [1, 1, length, length]
356
- """
357
- r = torch.arange(length, dtype=torch.float32)
358
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
-
361
-
362
- class FFN(nn.Module):
363
- def __init__(
364
- self,
365
- in_channels,
366
- out_channels,
367
- filter_channels,
368
- kernel_size,
369
- p_dropout=0.0,
370
- activation=None,
371
- causal=False,
372
- ):
373
- super().__init__()
374
- self.in_channels = in_channels
375
- self.out_channels = out_channels
376
- self.filter_channels = filter_channels
377
- self.kernel_size = kernel_size
378
- self.p_dropout = p_dropout
379
- self.activation = activation
380
- self.causal = causal
381
-
382
- if causal:
383
- self.padding = self._causal_padding
384
- else:
385
- self.padding = self._same_padding
386
-
387
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
- self.drop = nn.Dropout(p_dropout)
390
-
391
- def forward(self, x, x_mask):
392
- x = self.conv_1(self.padding(x * x_mask))
393
- if self.activation == "gelu":
394
- x = x * torch.sigmoid(1.702 * x)
395
- else:
396
- x = torch.relu(x)
397
- x = self.drop(x)
398
- x = self.conv_2(self.padding(x * x_mask))
399
- return x * x_mask
400
-
401
- def _causal_padding(self, x):
402
- if self.kernel_size == 1:
403
- return x
404
- pad_l = self.kernel_size - 1
405
- pad_r = 0
406
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
- x = F.pad(x, commons.convert_pad_shape(padding))
408
- return x
409
-
410
- def _same_padding(self, x):
411
- if self.kernel_size == 1:
412
- return x
413
- pad_l = (self.kernel_size - 1) // 2
414
- pad_r = self.kernel_size // 2
415
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
- x = F.pad(x, commons.convert_pad_shape(padding))
417
- return x
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from infer_pack import commons
9
+ from infer_pack import modules
10
+ from infer_pack.modules import LayerNorm
11
+
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_channels,
17
+ filter_channels,
18
+ n_heads,
19
+ n_layers,
20
+ kernel_size=1,
21
+ p_dropout=0.0,
22
+ window_size=10,
23
+ **kwargs
24
+ ):
25
+ super().__init__()
26
+ self.hidden_channels = hidden_channels
27
+ self.filter_channels = filter_channels
28
+ self.n_heads = n_heads
29
+ self.n_layers = n_layers
30
+ self.kernel_size = kernel_size
31
+ self.p_dropout = p_dropout
32
+ self.window_size = window_size
33
+
34
+ self.drop = nn.Dropout(p_dropout)
35
+ self.attn_layers = nn.ModuleList()
36
+ self.norm_layers_1 = nn.ModuleList()
37
+ self.ffn_layers = nn.ModuleList()
38
+ self.norm_layers_2 = nn.ModuleList()
39
+ for i in range(self.n_layers):
40
+ self.attn_layers.append(
41
+ MultiHeadAttention(
42
+ hidden_channels,
43
+ hidden_channels,
44
+ n_heads,
45
+ p_dropout=p_dropout,
46
+ window_size=window_size,
47
+ )
48
+ )
49
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
50
+ self.ffn_layers.append(
51
+ FFN(
52
+ hidden_channels,
53
+ hidden_channels,
54
+ filter_channels,
55
+ kernel_size,
56
+ p_dropout=p_dropout,
57
+ )
58
+ )
59
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
60
+
61
+ def forward(self, x, x_mask):
62
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
+ x = x * x_mask
64
+ for i in range(self.n_layers):
65
+ y = self.attn_layers[i](x, x, attn_mask)
66
+ y = self.drop(y)
67
+ x = self.norm_layers_1[i](x + y)
68
+
69
+ y = self.ffn_layers[i](x, x_mask)
70
+ y = self.drop(y)
71
+ x = self.norm_layers_2[i](x + y)
72
+ x = x * x_mask
73
+ return x
74
+
75
+
76
+ class Decoder(nn.Module):
77
+ def __init__(
78
+ self,
79
+ hidden_channels,
80
+ filter_channels,
81
+ n_heads,
82
+ n_layers,
83
+ kernel_size=1,
84
+ p_dropout=0.0,
85
+ proximal_bias=False,
86
+ proximal_init=True,
87
+ **kwargs
88
+ ):
89
+ super().__init__()
90
+ self.hidden_channels = hidden_channels
91
+ self.filter_channels = filter_channels
92
+ self.n_heads = n_heads
93
+ self.n_layers = n_layers
94
+ self.kernel_size = kernel_size
95
+ self.p_dropout = p_dropout
96
+ self.proximal_bias = proximal_bias
97
+ self.proximal_init = proximal_init
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.self_attn_layers = nn.ModuleList()
101
+ self.norm_layers_0 = nn.ModuleList()
102
+ self.encdec_attn_layers = nn.ModuleList()
103
+ self.norm_layers_1 = nn.ModuleList()
104
+ self.ffn_layers = nn.ModuleList()
105
+ self.norm_layers_2 = nn.ModuleList()
106
+ for i in range(self.n_layers):
107
+ self.self_attn_layers.append(
108
+ MultiHeadAttention(
109
+ hidden_channels,
110
+ hidden_channels,
111
+ n_heads,
112
+ p_dropout=p_dropout,
113
+ proximal_bias=proximal_bias,
114
+ proximal_init=proximal_init,
115
+ )
116
+ )
117
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
118
+ self.encdec_attn_layers.append(
119
+ MultiHeadAttention(
120
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
+ )
122
+ )
123
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
124
+ self.ffn_layers.append(
125
+ FFN(
126
+ hidden_channels,
127
+ hidden_channels,
128
+ filter_channels,
129
+ kernel_size,
130
+ p_dropout=p_dropout,
131
+ causal=True,
132
+ )
133
+ )
134
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
135
+
136
+ def forward(self, x, x_mask, h, h_mask):
137
+ """
138
+ x: decoder input
139
+ h: encoder output
140
+ """
141
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
+ device=x.device, dtype=x.dtype
143
+ )
144
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
+ x = x * x_mask
146
+ for i in range(self.n_layers):
147
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
148
+ y = self.drop(y)
149
+ x = self.norm_layers_0[i](x + y)
150
+
151
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
+ y = self.drop(y)
153
+ x = self.norm_layers_1[i](x + y)
154
+
155
+ y = self.ffn_layers[i](x, x_mask)
156
+ y = self.drop(y)
157
+ x = self.norm_layers_2[i](x + y)
158
+ x = x * x_mask
159
+ return x
160
+
161
+
162
+ class MultiHeadAttention(nn.Module):
163
+ def __init__(
164
+ self,
165
+ channels,
166
+ out_channels,
167
+ n_heads,
168
+ p_dropout=0.0,
169
+ window_size=None,
170
+ heads_share=True,
171
+ block_length=None,
172
+ proximal_bias=False,
173
+ proximal_init=False,
174
+ ):
175
+ super().__init__()
176
+ assert channels % n_heads == 0
177
+
178
+ self.channels = channels
179
+ self.out_channels = out_channels
180
+ self.n_heads = n_heads
181
+ self.p_dropout = p_dropout
182
+ self.window_size = window_size
183
+ self.heads_share = heads_share
184
+ self.block_length = block_length
185
+ self.proximal_bias = proximal_bias
186
+ self.proximal_init = proximal_init
187
+ self.attn = None
188
+
189
+ self.k_channels = channels // n_heads
190
+ self.conv_q = nn.Conv1d(channels, channels, 1)
191
+ self.conv_k = nn.Conv1d(channels, channels, 1)
192
+ self.conv_v = nn.Conv1d(channels, channels, 1)
193
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
+ self.drop = nn.Dropout(p_dropout)
195
+
196
+ if window_size is not None:
197
+ n_heads_rel = 1 if heads_share else n_heads
198
+ rel_stddev = self.k_channels**-0.5
199
+ self.emb_rel_k = nn.Parameter(
200
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
+ * rel_stddev
202
+ )
203
+ self.emb_rel_v = nn.Parameter(
204
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
+ * rel_stddev
206
+ )
207
+
208
+ nn.init.xavier_uniform_(self.conv_q.weight)
209
+ nn.init.xavier_uniform_(self.conv_k.weight)
210
+ nn.init.xavier_uniform_(self.conv_v.weight)
211
+ if proximal_init:
212
+ with torch.no_grad():
213
+ self.conv_k.weight.copy_(self.conv_q.weight)
214
+ self.conv_k.bias.copy_(self.conv_q.bias)
215
+
216
+ def forward(self, x, c, attn_mask=None):
217
+ q = self.conv_q(x)
218
+ k = self.conv_k(c)
219
+ v = self.conv_v(c)
220
+
221
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
+
223
+ x = self.conv_o(x)
224
+ return x
225
+
226
+ def attention(self, query, key, value, mask=None):
227
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
228
+ b, d, t_s, t_t = (*key.size(), query.size(2))
229
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
+
233
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
+ if self.window_size is not None:
235
+ assert (
236
+ t_s == t_t
237
+ ), "Relative attention is only available for self-attention."
238
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
+ rel_logits = self._matmul_with_relative_keys(
240
+ query / math.sqrt(self.k_channels), key_relative_embeddings
241
+ )
242
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
243
+ scores = scores + scores_local
244
+ if self.proximal_bias:
245
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
246
+ scores = scores + self._attention_bias_proximal(t_s).to(
247
+ device=scores.device, dtype=scores.dtype
248
+ )
249
+ if mask is not None:
250
+ scores = scores.masked_fill(mask == 0, -1e4)
251
+ if self.block_length is not None:
252
+ assert (
253
+ t_s == t_t
254
+ ), "Local attention is only available for self-attention."
255
+ block_mask = (
256
+ torch.ones_like(scores)
257
+ .triu(-self.block_length)
258
+ .tril(self.block_length)
259
+ )
260
+ scores = scores.masked_fill(block_mask == 0, -1e4)
261
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
+ p_attn = self.drop(p_attn)
263
+ output = torch.matmul(p_attn, value)
264
+ if self.window_size is not None:
265
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
266
+ value_relative_embeddings = self._get_relative_embeddings(
267
+ self.emb_rel_v, t_s
268
+ )
269
+ output = output + self._matmul_with_relative_values(
270
+ relative_weights, value_relative_embeddings
271
+ )
272
+ output = (
273
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
274
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
+ return output, p_attn
276
+
277
+ def _matmul_with_relative_values(self, x, y):
278
+ """
279
+ x: [b, h, l, m]
280
+ y: [h or 1, m, d]
281
+ ret: [b, h, l, d]
282
+ """
283
+ ret = torch.matmul(x, y.unsqueeze(0))
284
+ return ret
285
+
286
+ def _matmul_with_relative_keys(self, x, y):
287
+ """
288
+ x: [b, h, l, d]
289
+ y: [h or 1, m, d]
290
+ ret: [b, h, l, m]
291
+ """
292
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
+ return ret
294
+
295
+ def _get_relative_embeddings(self, relative_embeddings, length):
296
+ max_relative_position = 2 * self.window_size + 1
297
+ # Pad first before slice to avoid using cond ops.
298
+ pad_length = max(length - (self.window_size + 1), 0)
299
+ slice_start_position = max((self.window_size + 1) - length, 0)
300
+ slice_end_position = slice_start_position + 2 * length - 1
301
+ if pad_length > 0:
302
+ padded_relative_embeddings = F.pad(
303
+ relative_embeddings,
304
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
+ )
306
+ else:
307
+ padded_relative_embeddings = relative_embeddings
308
+ used_relative_embeddings = padded_relative_embeddings[
309
+ :, slice_start_position:slice_end_position
310
+ ]
311
+ return used_relative_embeddings
312
+
313
+ def _relative_position_to_absolute_position(self, x):
314
+ """
315
+ x: [b, h, l, 2*l-1]
316
+ ret: [b, h, l, l]
317
+ """
318
+ batch, heads, length, _ = x.size()
319
+ # Concat columns of pad to shift from relative to absolute indexing.
320
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
+
322
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
+ x_flat = x.view([batch, heads, length * 2 * length])
324
+ x_flat = F.pad(
325
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
+ )
327
+
328
+ # Reshape and slice out the padded elements.
329
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
+ :, :, :length, length - 1 :
331
+ ]
332
+ return x_final
333
+
334
+ def _absolute_position_to_relative_position(self, x):
335
+ """
336
+ x: [b, h, l, l]
337
+ ret: [b, h, l, 2*l-1]
338
+ """
339
+ batch, heads, length, _ = x.size()
340
+ # padd along column
341
+ x = F.pad(
342
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
+ )
344
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
+ # add 0's in the beginning that will skew the elements after reshape
346
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
+ return x_final
349
+
350
+ def _attention_bias_proximal(self, length):
351
+ """Bias for self-attention to encourage attention to close positions.
352
+ Args:
353
+ length: an integer scalar.
354
+ Returns:
355
+ a Tensor with shape [1, 1, length, length]
356
+ """
357
+ r = torch.arange(length, dtype=torch.float32)
358
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
+
361
+
362
+ class FFN(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels,
366
+ out_channels,
367
+ filter_channels,
368
+ kernel_size,
369
+ p_dropout=0.0,
370
+ activation=None,
371
+ causal=False,
372
+ ):
373
+ super().__init__()
374
+ self.in_channels = in_channels
375
+ self.out_channels = out_channels
376
+ self.filter_channels = filter_channels
377
+ self.kernel_size = kernel_size
378
+ self.p_dropout = p_dropout
379
+ self.activation = activation
380
+ self.causal = causal
381
+
382
+ if causal:
383
+ self.padding = self._causal_padding
384
+ else:
385
+ self.padding = self._same_padding
386
+
387
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
+ self.drop = nn.Dropout(p_dropout)
390
+
391
+ def forward(self, x, x_mask):
392
+ x = self.conv_1(self.padding(x * x_mask))
393
+ if self.activation == "gelu":
394
+ x = x * torch.sigmoid(1.702 * x)
395
+ else:
396
+ x = torch.relu(x)
397
+ x = self.drop(x)
398
+ x = self.conv_2(self.padding(x * x_mask))
399
+ return x * x_mask
400
+
401
+ def _causal_padding(self, x):
402
+ if self.kernel_size == 1:
403
+ return x
404
+ pad_l = self.kernel_size - 1
405
+ pad_r = 0
406
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
+ x = F.pad(x, commons.convert_pad_shape(padding))
408
+ return x
409
+
410
+ def _same_padding(self, x):
411
+ if self.kernel_size == 1:
412
+ return x
413
+ pad_l = (self.kernel_size - 1) // 2
414
+ pad_r = self.kernel_size // 2
415
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
+ x = F.pad(x, commons.convert_pad_shape(padding))
417
+ return x
infer_pack/commons.py CHANGED
@@ -1,164 +1,166 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
-
8
- def init_weights(m, mean=0.0, std=0.01):
9
- classname = m.__class__.__name__
10
- if classname.find("Conv") != -1:
11
- m.weight.data.normal_(mean, std)
12
-
13
-
14
- def get_padding(kernel_size, dilation=1):
15
- return int((kernel_size * dilation - dilation) / 2)
16
-
17
-
18
- def convert_pad_shape(pad_shape):
19
- l = pad_shape[::-1]
20
- pad_shape = [item for sublist in l for item in sublist]
21
- return pad_shape
22
-
23
-
24
- def kl_divergence(m_p, logs_p, m_q, logs_q):
25
- """KL(P||Q)"""
26
- kl = (logs_q - logs_p) - 0.5
27
- kl += (
28
- 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29
- )
30
- return kl
31
-
32
-
33
- def rand_gumbel(shape):
34
- """Sample from the Gumbel distribution, protect from overflows."""
35
- uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36
- return -torch.log(-torch.log(uniform_samples))
37
-
38
-
39
- def rand_gumbel_like(x):
40
- g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
41
- return g
42
-
43
-
44
- def slice_segments(x, ids_str, segment_size=4):
45
- ret = torch.zeros_like(x[:, :, :segment_size])
46
- for i in range(x.size(0)):
47
- idx_str = ids_str[i]
48
- idx_end = idx_str + segment_size
49
- ret[i] = x[i, :, idx_str:idx_end]
50
- return ret
51
- def slice_segments2(x, ids_str, segment_size=4):
52
- ret = torch.zeros_like(x[:, :segment_size])
53
- for i in range(x.size(0)):
54
- idx_str = ids_str[i]
55
- idx_end = idx_str + segment_size
56
- ret[i] = x[i, idx_str:idx_end]
57
- return ret
58
-
59
-
60
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
61
- b, d, t = x.size()
62
- if x_lengths is None:
63
- x_lengths = t
64
- ids_str_max = x_lengths - segment_size + 1
65
- ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
66
- ret = slice_segments(x, ids_str, segment_size)
67
- return ret, ids_str
68
-
69
-
70
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
71
- position = torch.arange(length, dtype=torch.float)
72
- num_timescales = channels // 2
73
- log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
74
- num_timescales - 1
75
- )
76
- inv_timescales = min_timescale * torch.exp(
77
- torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
78
- )
79
- scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
80
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
81
- signal = F.pad(signal, [0, 0, 0, channels % 2])
82
- signal = signal.view(1, channels, length)
83
- return signal
84
-
85
-
86
- def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
87
- b, channels, length = x.size()
88
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
89
- return x + signal.to(dtype=x.dtype, device=x.device)
90
-
91
-
92
- def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
93
- b, channels, length = x.size()
94
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
95
- return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
96
-
97
-
98
- def subsequent_mask(length):
99
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
100
- return mask
101
-
102
-
103
- @torch.jit.script
104
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
105
- n_channels_int = n_channels[0]
106
- in_act = input_a + input_b
107
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
108
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
109
- acts = t_act * s_act
110
- return acts
111
-
112
-
113
- def convert_pad_shape(pad_shape):
114
- l = pad_shape[::-1]
115
- pad_shape = [item for sublist in l for item in sublist]
116
- return pad_shape
117
-
118
-
119
- def shift_1d(x):
120
- x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
121
- return x
122
-
123
-
124
- def sequence_mask(length, max_length=None):
125
- if max_length is None:
126
- max_length = length.max()
127
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
128
- return x.unsqueeze(0) < length.unsqueeze(1)
129
-
130
-
131
- def generate_path(duration, mask):
132
- """
133
- duration: [b, 1, t_x]
134
- mask: [b, 1, t_y, t_x]
135
- """
136
- device = duration.device
137
-
138
- b, _, t_y, t_x = mask.shape
139
- cum_duration = torch.cumsum(duration, -1)
140
-
141
- cum_duration_flat = cum_duration.view(b * t_x)
142
- path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
143
- path = path.view(b, t_x, t_y)
144
- path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
145
- path = path.unsqueeze(1).transpose(2, 3) * mask
146
- return path
147
-
148
-
149
- def clip_grad_value_(parameters, clip_value, norm_type=2):
150
- if isinstance(parameters, torch.Tensor):
151
- parameters = [parameters]
152
- parameters = list(filter(lambda p: p.grad is not None, parameters))
153
- norm_type = float(norm_type)
154
- if clip_value is not None:
155
- clip_value = float(clip_value)
156
-
157
- total_norm = 0
158
- for p in parameters:
159
- param_norm = p.grad.data.norm(norm_type)
160
- total_norm += param_norm.item() ** norm_type
161
- if clip_value is not None:
162
- p.grad.data.clamp_(min=-clip_value, max=clip_value)
163
- total_norm = total_norm ** (1.0 / norm_type)
164
- return total_norm
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
25
+ """KL(P||Q)"""
26
+ kl = (logs_q - logs_p) - 0.5
27
+ kl += (
28
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29
+ )
30
+ return kl
31
+
32
+
33
+ def rand_gumbel(shape):
34
+ """Sample from the Gumbel distribution, protect from overflows."""
35
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36
+ return -torch.log(-torch.log(uniform_samples))
37
+
38
+
39
+ def rand_gumbel_like(x):
40
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
41
+ return g
42
+
43
+
44
+ def slice_segments(x, ids_str, segment_size=4):
45
+ ret = torch.zeros_like(x[:, :, :segment_size])
46
+ for i in range(x.size(0)):
47
+ idx_str = ids_str[i]
48
+ idx_end = idx_str + segment_size
49
+ ret[i] = x[i, :, idx_str:idx_end]
50
+ return ret
51
+
52
+
53
+ def slice_segments2(x, ids_str, segment_size=4):
54
+ ret = torch.zeros_like(x[:, :segment_size])
55
+ for i in range(x.size(0)):
56
+ idx_str = ids_str[i]
57
+ idx_end = idx_str + segment_size
58
+ ret[i] = x[i, idx_str:idx_end]
59
+ return ret
60
+
61
+
62
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
63
+ b, d, t = x.size()
64
+ if x_lengths is None:
65
+ x_lengths = t
66
+ ids_str_max = x_lengths - segment_size + 1
67
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
68
+ ret = slice_segments(x, ids_str, segment_size)
69
+ return ret, ids_str
70
+
71
+
72
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
73
+ position = torch.arange(length, dtype=torch.float)
74
+ num_timescales = channels // 2
75
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
76
+ num_timescales - 1
77
+ )
78
+ inv_timescales = min_timescale * torch.exp(
79
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
80
+ )
81
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
82
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
83
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
84
+ signal = signal.view(1, channels, length)
85
+ return signal
86
+
87
+
88
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
89
+ b, channels, length = x.size()
90
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
91
+ return x + signal.to(dtype=x.dtype, device=x.device)
92
+
93
+
94
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
95
+ b, channels, length = x.size()
96
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
97
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
+
99
+
100
+ def subsequent_mask(length):
101
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
+ return mask
103
+
104
+
105
+ @torch.jit.script
106
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
+ n_channels_int = n_channels[0]
108
+ in_act = input_a + input_b
109
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
+ acts = t_act * s_act
112
+ return acts
113
+
114
+
115
+ def convert_pad_shape(pad_shape):
116
+ l = pad_shape[::-1]
117
+ pad_shape = [item for sublist in l for item in sublist]
118
+ return pad_shape
119
+
120
+
121
+ def shift_1d(x):
122
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
123
+ return x
124
+
125
+
126
+ def sequence_mask(length, max_length=None):
127
+ if max_length is None:
128
+ max_length = length.max()
129
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
130
+ return x.unsqueeze(0) < length.unsqueeze(1)
131
+
132
+
133
+ def generate_path(duration, mask):
134
+ """
135
+ duration: [b, 1, t_x]
136
+ mask: [b, 1, t_y, t_x]
137
+ """
138
+ device = duration.device
139
+
140
+ b, _, t_y, t_x = mask.shape
141
+ cum_duration = torch.cumsum(duration, -1)
142
+
143
+ cum_duration_flat = cum_duration.view(b * t_x)
144
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
145
+ path = path.view(b, t_x, t_y)
146
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
147
+ path = path.unsqueeze(1).transpose(2, 3) * mask
148
+ return path
149
+
150
+
151
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
152
+ if isinstance(parameters, torch.Tensor):
153
+ parameters = [parameters]
154
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
155
+ norm_type = float(norm_type)
156
+ if clip_value is not None:
157
+ clip_value = float(clip_value)
158
+
159
+ total_norm = 0
160
+ for p in parameters:
161
+ param_norm = p.grad.data.norm(norm_type)
162
+ total_norm += param_norm.item() ** norm_type
163
+ if clip_value is not None:
164
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
165
+ total_norm = total_norm ** (1.0 / norm_type)
166
+ return total_norm
infer_pack/models.py CHANGED
@@ -1,892 +1,1116 @@
1
- import math,pdb,os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
- class TextEncoder256(nn.Module):
16
- def __init__(
17
- self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True ):
18
- super().__init__()
19
- self.out_channels = out_channels
20
- self.hidden_channels = hidden_channels
21
- self.filter_channels = filter_channels
22
- self.n_heads = n_heads
23
- self.n_layers = n_layers
24
- self.kernel_size = kernel_size
25
- self.p_dropout = p_dropout
26
- self.emb_phone = nn.Linear(256, hidden_channels)
27
- self.lrelu=nn.LeakyReLU(0.1,inplace=True)
28
- if(f0==True):
29
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
30
- self.encoder = attentions.Encoder(
31
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
32
- )
33
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
34
-
35
- def forward(self, phone, pitch, lengths):
36
- if(pitch==None):
37
- x = self.emb_phone(phone)
38
- else:
39
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
40
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
41
- x=self.lrelu(x)
42
- x = torch.transpose(x, 1, -1) # [b, h, t]
43
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
44
- x.dtype
45
- )
46
- x = self.encoder(x * x_mask, x_mask)
47
- stats = self.proj(x) * x_mask
48
-
49
- m, logs = torch.split(stats, self.out_channels, dim=1)
50
- return m, logs, x_mask
51
- class TextEncoder256Sim(nn.Module):
52
- def __init__( self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True):
53
- super().__init__()
54
- self.out_channels = out_channels
55
- self.hidden_channels = hidden_channels
56
- self.filter_channels = filter_channels
57
- self.n_heads = n_heads
58
- self.n_layers = n_layers
59
- self.kernel_size = kernel_size
60
- self.p_dropout = p_dropout
61
- self.emb_phone = nn.Linear(256, hidden_channels)
62
- self.lrelu=nn.LeakyReLU(0.1,inplace=True)
63
- if(f0==True):
64
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
65
- self.encoder = attentions.Encoder(
66
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
67
- )
68
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
69
-
70
- def forward(self, phone, pitch, lengths):
71
- if(pitch==None):
72
- x = self.emb_phone(phone)
73
- else:
74
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
75
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
76
- x=self.lrelu(x)
77
- x = torch.transpose(x, 1, -1) # [b, h, t]
78
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
79
- x = self.encoder(x * x_mask, x_mask)
80
- x = self.proj(x) * x_mask
81
- return x,x_mask
82
- class ResidualCouplingBlock(nn.Module):
83
- def __init__(
84
- self,
85
- channels,
86
- hidden_channels,
87
- kernel_size,
88
- dilation_rate,
89
- n_layers,
90
- n_flows=4,
91
- gin_channels=0,
92
- ):
93
- super().__init__()
94
- self.channels = channels
95
- self.hidden_channels = hidden_channels
96
- self.kernel_size = kernel_size
97
- self.dilation_rate = dilation_rate
98
- self.n_layers = n_layers
99
- self.n_flows = n_flows
100
- self.gin_channels = gin_channels
101
-
102
- self.flows = nn.ModuleList()
103
- for i in range(n_flows):
104
- self.flows.append(
105
- modules.ResidualCouplingLayer(
106
- channels,
107
- hidden_channels,
108
- kernel_size,
109
- dilation_rate,
110
- n_layers,
111
- gin_channels=gin_channels,
112
- mean_only=True,
113
- )
114
- )
115
- self.flows.append(modules.Flip())
116
-
117
- def forward(self, x, x_mask, g=None, reverse=False):
118
- if not reverse:
119
- for flow in self.flows:
120
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
121
- else:
122
- for flow in reversed(self.flows):
123
- x = flow(x, x_mask, g=g, reverse=reverse)
124
- return x
125
-
126
- def remove_weight_norm(self):
127
- for i in range(self.n_flows):
128
- self.flows[i * 2].remove_weight_norm()
129
- class PosteriorEncoder(nn.Module):
130
- def __init__(
131
- self,
132
- in_channels,
133
- out_channels,
134
- hidden_channels,
135
- kernel_size,
136
- dilation_rate,
137
- n_layers,
138
- gin_channels=0,
139
- ):
140
- super().__init__()
141
- self.in_channels = in_channels
142
- self.out_channels = out_channels
143
- self.hidden_channels = hidden_channels
144
- self.kernel_size = kernel_size
145
- self.dilation_rate = dilation_rate
146
- self.n_layers = n_layers
147
- self.gin_channels = gin_channels
148
-
149
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
150
- self.enc = modules.WN(
151
- hidden_channels,
152
- kernel_size,
153
- dilation_rate,
154
- n_layers,
155
- gin_channels=gin_channels,
156
- )
157
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
158
-
159
- def forward(self, x, x_lengths, g=None):
160
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
161
- x.dtype
162
- )
163
- x = self.pre(x) * x_mask
164
- x = self.enc(x, x_mask, g=g)
165
- stats = self.proj(x) * x_mask
166
- m, logs = torch.split(stats, self.out_channels, dim=1)
167
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
168
- return z, m, logs, x_mask
169
-
170
- def remove_weight_norm(self):
171
- self.enc.remove_weight_norm()
172
- class Generator(torch.nn.Module):
173
- def __init__(
174
- self,
175
- initial_channel,
176
- resblock,
177
- resblock_kernel_sizes,
178
- resblock_dilation_sizes,
179
- upsample_rates,
180
- upsample_initial_channel,
181
- upsample_kernel_sizes,
182
- gin_channels=0,
183
- ):
184
- super(Generator, self).__init__()
185
- self.num_kernels = len(resblock_kernel_sizes)
186
- self.num_upsamples = len(upsample_rates)
187
- self.conv_pre = Conv1d(
188
- initial_channel, upsample_initial_channel, 7, 1, padding=3
189
- )
190
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
191
-
192
- self.ups = nn.ModuleList()
193
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
194
- self.ups.append(
195
- weight_norm(
196
- ConvTranspose1d(
197
- upsample_initial_channel // (2**i),
198
- upsample_initial_channel // (2 ** (i + 1)),
199
- k,
200
- u,
201
- padding=(k - u) // 2,
202
- )
203
- )
204
- )
205
-
206
- self.resblocks = nn.ModuleList()
207
- for i in range(len(self.ups)):
208
- ch = upsample_initial_channel // (2 ** (i + 1))
209
- for j, (k, d) in enumerate(
210
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
211
- ):
212
- self.resblocks.append(resblock(ch, k, d))
213
-
214
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
215
- self.ups.apply(init_weights)
216
-
217
- if gin_channels != 0:
218
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
219
-
220
- def forward(self, x, g=None):
221
- x = self.conv_pre(x)
222
- if g is not None:
223
- x = x + self.cond(g)
224
-
225
- for i in range(self.num_upsamples):
226
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
227
- x = self.ups[i](x)
228
- xs = None
229
- for j in range(self.num_kernels):
230
- if xs is None:
231
- xs = self.resblocks[i * self.num_kernels + j](x)
232
- else:
233
- xs += self.resblocks[i * self.num_kernels + j](x)
234
- x = xs / self.num_kernels
235
- x = F.leaky_relu(x)
236
- x = self.conv_post(x)
237
- x = torch.tanh(x)
238
-
239
- return x
240
-
241
- def remove_weight_norm(self):
242
- for l in self.ups:
243
- remove_weight_norm(l)
244
- for l in self.resblocks:
245
- l.remove_weight_norm()
246
- class SineGen(torch.nn.Module):
247
- """ Definition of sine generator
248
- SineGen(samp_rate, harmonic_num = 0,
249
- sine_amp = 0.1, noise_std = 0.003,
250
- voiced_threshold = 0,
251
- flag_for_pulse=False)
252
- samp_rate: sampling rate in Hz
253
- harmonic_num: number of harmonic overtones (default 0)
254
- sine_amp: amplitude of sine-wavefrom (default 0.1)
255
- noise_std: std of Gaussian noise (default 0.003)
256
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
257
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
258
- Note: when flag_for_pulse is True, the first time step of a voiced
259
- segment is always sin(np.pi) or cos(0)
260
- """
261
-
262
- def __init__(self, samp_rate, harmonic_num=0,
263
- sine_amp=0.1, noise_std=0.003,
264
- voiced_threshold=0,
265
- flag_for_pulse=False):
266
- super(SineGen, self).__init__()
267
- self.sine_amp = sine_amp
268
- self.noise_std = noise_std
269
- self.harmonic_num = harmonic_num
270
- self.dim = self.harmonic_num + 1
271
- self.sampling_rate = samp_rate
272
- self.voiced_threshold = voiced_threshold
273
-
274
- def _f02uv(self, f0):
275
- # generate uv signal
276
- uv = torch.ones_like(f0)
277
- uv = uv * (f0 > self.voiced_threshold)
278
- return uv
279
-
280
- def forward(self, f0,upp):
281
- """ sine_tensor, uv = forward(f0)
282
- input F0: tensor(batchsize=1, length, dim=1)
283
- f0 for unvoiced steps should be 0
284
- output sine_tensor: tensor(batchsize=1, length, dim)
285
- output uv: tensor(batchsize=1, length, 1)
286
- """
287
- with torch.no_grad():
288
- f0 = f0[:, None].transpose(1, 2)
289
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,device=f0.device)
290
- # fundamental component
291
- f0_buf[:, :, 0] = f0[:, :, 0]
292
- for idx in np.arange(self.harmonic_num):f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)# idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
293
- rad_values = (f0_buf / self.sampling_rate) % 1###%1意味着n_har的乘积无法后处理优化
294
- rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
295
- rand_ini[:, 0] = 0
296
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
297
- tmp_over_one = torch.cumsum(rad_values, 1)# % 1 #####%1意味着后面的cumsum无法再优化
298
- tmp_over_one*=upp
299
- tmp_over_one=F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=upp, mode='linear', align_corners=True).transpose(2, 1)
300
- rad_values=F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)#######
301
- tmp_over_one%=1
302
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
303
- cumsum_shift = torch.zeros_like(rad_values)
304
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
305
- sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
306
- sine_waves = sine_waves * self.sine_amp
307
- uv = self._f02uv(f0)
308
- uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
309
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
310
- noise = noise_amp * torch.randn_like(sine_waves)
311
- sine_waves = sine_waves * uv + noise
312
- return sine_waves, uv, noise
313
- class SourceModuleHnNSF(torch.nn.Module):
314
- """ SourceModule for hn-nsf
315
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
316
- add_noise_std=0.003, voiced_threshod=0)
317
- sampling_rate: sampling_rate in Hz
318
- harmonic_num: number of harmonic above F0 (default: 0)
319
- sine_amp: amplitude of sine source signal (default: 0.1)
320
- add_noise_std: std of additive Gaussian noise (default: 0.003)
321
- note that amplitude of noise in unvoiced is decided
322
- by sine_amp
323
- voiced_threshold: threhold to set U/V given F0 (default: 0)
324
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
325
- F0_sampled (batchsize, length, 1)
326
- Sine_source (batchsize, length, 1)
327
- noise_source (batchsize, length 1)
328
- uv (batchsize, length, 1)
329
- """
330
-
331
- def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
332
- add_noise_std=0.003, voiced_threshod=0,is_half=True):
333
- super(SourceModuleHnNSF, self).__init__()
334
-
335
- self.sine_amp = sine_amp
336
- self.noise_std = add_noise_std
337
- self.is_half=is_half
338
- # to produce sine waveforms
339
- self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
340
- sine_amp, add_noise_std, voiced_threshod)
341
-
342
- # to merge source harmonics into a single excitation
343
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
344
- self.l_tanh = torch.nn.Tanh()
345
-
346
- def forward(self, x,upp=None):
347
- sine_wavs, uv, _ = self.l_sin_gen(x,upp)
348
- if(self.is_half):sine_wavs=sine_wavs.half()
349
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
350
- return sine_merge,None,None# noise, uv
351
- class GeneratorNSF(torch.nn.Module):
352
- def __init__(
353
- self,
354
- initial_channel,
355
- resblock,
356
- resblock_kernel_sizes,
357
- resblock_dilation_sizes,
358
- upsample_rates,
359
- upsample_initial_channel,
360
- upsample_kernel_sizes,
361
- gin_channels,
362
- sr,
363
- is_half=False
364
- ):
365
- super(GeneratorNSF, self).__init__()
366
- self.num_kernels = len(resblock_kernel_sizes)
367
- self.num_upsamples = len(upsample_rates)
368
-
369
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
370
- self.m_source = SourceModuleHnNSF(
371
- sampling_rate=sr,
372
- harmonic_num=0,
373
- is_half=is_half
374
- )
375
- self.noise_convs = nn.ModuleList()
376
- self.conv_pre = Conv1d(
377
- initial_channel, upsample_initial_channel, 7, 1, padding=3
378
- )
379
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
380
-
381
- self.ups = nn.ModuleList()
382
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
383
- c_cur = upsample_initial_channel // (2 ** (i + 1))
384
- self.ups.append(
385
- weight_norm(
386
- ConvTranspose1d(
387
- upsample_initial_channel // (2**i),
388
- upsample_initial_channel // (2 ** (i + 1)),
389
- k,
390
- u,
391
- padding=(k - u) // 2,
392
- )
393
- )
394
- )
395
- if i + 1 < len(upsample_rates):
396
- stride_f0 = np.prod(upsample_rates[i + 1:])
397
- self.noise_convs.append(Conv1d(
398
- 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
399
- else:
400
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
401
-
402
- self.resblocks = nn.ModuleList()
403
- for i in range(len(self.ups)):
404
- ch = upsample_initial_channel // (2 ** (i + 1))
405
- for j, (k, d) in enumerate(
406
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
407
- ):
408
- self.resblocks.append(resblock(ch, k, d))
409
-
410
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
411
- self.ups.apply(init_weights)
412
-
413
- if gin_channels != 0:
414
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
415
-
416
- self.upp=np.prod(upsample_rates)
417
-
418
- def forward(self, x, f0,g=None):
419
- har_source, noi_source, uv = self.m_source(f0,self.upp)
420
- har_source = har_source.transpose(1, 2)
421
- x = self.conv_pre(x)
422
- if g is not None:
423
- x = x + self.cond(g)
424
-
425
- for i in range(self.num_upsamples):
426
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
427
- x = self.ups[i](x)
428
- x_source = self.noise_convs[i](har_source)
429
- x = x + x_source
430
- xs = None
431
- for j in range(self.num_kernels):
432
- if xs is None:
433
- xs = self.resblocks[i * self.num_kernels + j](x)
434
- else:
435
- xs += self.resblocks[i * self.num_kernels + j](x)
436
- x = xs / self.num_kernels
437
- x = F.leaky_relu(x)
438
- x = self.conv_post(x)
439
- x = torch.tanh(x)
440
- return x
441
-
442
- def remove_weight_norm(self):
443
- for l in self.ups:
444
- remove_weight_norm(l)
445
- for l in self.resblocks:
446
- l.remove_weight_norm()
447
- sr2sr={
448
- "32k":32000,
449
- "40k":40000,
450
- "48k":48000,
451
- }
452
- class SynthesizerTrnMs256NSFsid(nn.Module):
453
- def __init__(
454
- self,
455
- spec_channels,
456
- segment_size,
457
- inter_channels,
458
- hidden_channels,
459
- filter_channels,
460
- n_heads,
461
- n_layers,
462
- kernel_size,
463
- p_dropout,
464
- resblock,
465
- resblock_kernel_sizes,
466
- resblock_dilation_sizes,
467
- upsample_rates,
468
- upsample_initial_channel,
469
- upsample_kernel_sizes,
470
- spk_embed_dim,
471
- gin_channels,
472
- sr,
473
- **kwargs
474
- ):
475
-
476
- super().__init__()
477
- if(type(sr)==type("strr")):
478
- sr=sr2sr[sr]
479
- self.spec_channels = spec_channels
480
- self.inter_channels = inter_channels
481
- self.hidden_channels = hidden_channels
482
- self.filter_channels = filter_channels
483
- self.n_heads = n_heads
484
- self.n_layers = n_layers
485
- self.kernel_size = kernel_size
486
- self.p_dropout = p_dropout
487
- self.resblock = resblock
488
- self.resblock_kernel_sizes = resblock_kernel_sizes
489
- self.resblock_dilation_sizes = resblock_dilation_sizes
490
- self.upsample_rates = upsample_rates
491
- self.upsample_initial_channel = upsample_initial_channel
492
- self.upsample_kernel_sizes = upsample_kernel_sizes
493
- self.segment_size = segment_size
494
- self.gin_channels = gin_channels
495
- # self.hop_length = hop_length#
496
- self.spk_embed_dim=spk_embed_dim
497
- self.enc_p = TextEncoder256(
498
- inter_channels,
499
- hidden_channels,
500
- filter_channels,
501
- n_heads,
502
- n_layers,
503
- kernel_size,
504
- p_dropout,
505
- )
506
- self.dec = GeneratorNSF(
507
- inter_channels,
508
- resblock,
509
- resblock_kernel_sizes,
510
- resblock_dilation_sizes,
511
- upsample_rates,
512
- upsample_initial_channel,
513
- upsample_kernel_sizes,
514
- gin_channels=gin_channels, sr=sr, is_half=kwargs["is_half"]
515
- )
516
- self.enc_q = PosteriorEncoder(
517
- spec_channels,
518
- inter_channels,
519
- hidden_channels,
520
- 5,
521
- 1,
522
- 16,
523
- gin_channels=gin_channels,
524
- )
525
- self.flow = ResidualCouplingBlock(
526
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
527
- )
528
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
529
- print("gin_channels:",gin_channels,"self.spk_embed_dim:",self.spk_embed_dim)
530
- def remove_weight_norm(self):
531
- self.dec.remove_weight_norm()
532
- self.flow.remove_weight_norm()
533
- self.enc_q.remove_weight_norm()
534
-
535
- def forward(self, phone, phone_lengths, pitch,pitchf, y, y_lengths,ds):#这里ds是id,[bs,1]
536
- # print(1,pitch.shape)#[bs,t]
537
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
538
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
539
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
540
- z_p = self.flow(z, y_mask, g=g)
541
- z_slice, ids_slice = commons.rand_slice_segments(
542
- z, y_lengths, self.segment_size
543
- )
544
- # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
545
- pitchf = commons.slice_segments2(
546
- pitchf, ids_slice, self.segment_size
547
- )
548
- # print(-2,pitchf.shape,z_slice.shape)
549
- o = self.dec(z_slice,pitchf, g=g)
550
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
551
-
552
- def infer(self, phone, phone_lengths, pitch, nsff0,sid,max_len=None):
553
- g = self.emb_g(sid).unsqueeze(-1)
554
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
555
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
556
- z = self.flow(z_p, x_mask, g=g, reverse=True)
557
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0,g=g)
558
- return o, x_mask, (z, z_p, m_p, logs_p)
559
- class SynthesizerTrnMs256NSFsid_nono(nn.Module):
560
- def __init__(
561
- self,
562
- spec_channels,
563
- segment_size,
564
- inter_channels,
565
- hidden_channels,
566
- filter_channels,
567
- n_heads,
568
- n_layers,
569
- kernel_size,
570
- p_dropout,
571
- resblock,
572
- resblock_kernel_sizes,
573
- resblock_dilation_sizes,
574
- upsample_rates,
575
- upsample_initial_channel,
576
- upsample_kernel_sizes,
577
- spk_embed_dim,
578
- gin_channels,
579
- sr=None,
580
- **kwargs
581
- ):
582
-
583
- super().__init__()
584
- self.spec_channels = spec_channels
585
- self.inter_channels = inter_channels
586
- self.hidden_channels = hidden_channels
587
- self.filter_channels = filter_channels
588
- self.n_heads = n_heads
589
- self.n_layers = n_layers
590
- self.kernel_size = kernel_size
591
- self.p_dropout = p_dropout
592
- self.resblock = resblock
593
- self.resblock_kernel_sizes = resblock_kernel_sizes
594
- self.resblock_dilation_sizes = resblock_dilation_sizes
595
- self.upsample_rates = upsample_rates
596
- self.upsample_initial_channel = upsample_initial_channel
597
- self.upsample_kernel_sizes = upsample_kernel_sizes
598
- self.segment_size = segment_size
599
- self.gin_channels = gin_channels
600
- # self.hop_length = hop_length#
601
- self.spk_embed_dim=spk_embed_dim
602
- self.enc_p = TextEncoder256(
603
- inter_channels,
604
- hidden_channels,
605
- filter_channels,
606
- n_heads,
607
- n_layers,
608
- kernel_size,
609
- p_dropout,f0=False
610
- )
611
- self.dec = Generator(
612
- inter_channels,
613
- resblock,
614
- resblock_kernel_sizes,
615
- resblock_dilation_sizes,
616
- upsample_rates,
617
- upsample_initial_channel,
618
- upsample_kernel_sizes,
619
- gin_channels=gin_channels
620
- )
621
- self.enc_q = PosteriorEncoder(
622
- spec_channels,
623
- inter_channels,
624
- hidden_channels,
625
- 5,
626
- 1,
627
- 16,
628
- gin_channels=gin_channels,
629
- )
630
- self.flow = ResidualCouplingBlock(
631
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
632
- )
633
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
634
- print("gin_channels:",gin_channels,"self.spk_embed_dim:",self.spk_embed_dim)
635
-
636
- def remove_weight_norm(self):
637
- self.dec.remove_weight_norm()
638
- self.flow.remove_weight_norm()
639
- self.enc_q.remove_weight_norm()
640
-
641
- def forward(self, phone, phone_lengths, y, y_lengths,ds):#这里ds是id,[bs,1]
642
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
643
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
644
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
645
- z_p = self.flow(z, y_mask, g=g)
646
- z_slice, ids_slice = commons.rand_slice_segments(
647
- z, y_lengths, self.segment_size
648
- )
649
- o = self.dec(z_slice, g=g)
650
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
651
-
652
- def infer(self, phone, phone_lengths,sid,max_len=None):
653
- g = self.emb_g(sid).unsqueeze(-1)
654
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
655
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
656
- z = self.flow(z_p, x_mask, g=g, reverse=True)
657
- o = self.dec((z * x_mask)[:, :, :max_len],g=g)
658
- return o, x_mask, (z, z_p, m_p, logs_p)
659
- class SynthesizerTrnMs256NSFsid_sim(nn.Module):
660
- """
661
- Synthesizer for Training
662
- """
663
-
664
- def __init__(
665
- self,
666
- spec_channels,
667
- segment_size,
668
- inter_channels,
669
- hidden_channels,
670
- filter_channels,
671
- n_heads,
672
- n_layers,
673
- kernel_size,
674
- p_dropout,
675
- resblock,
676
- resblock_kernel_sizes,
677
- resblock_dilation_sizes,
678
- upsample_rates,
679
- upsample_initial_channel,
680
- upsample_kernel_sizes,
681
- spk_embed_dim,
682
- # hop_length,
683
- gin_channels=0,
684
- use_sdp=True,
685
- **kwargs
686
- ):
687
-
688
- super().__init__()
689
- self.spec_channels = spec_channels
690
- self.inter_channels = inter_channels
691
- self.hidden_channels = hidden_channels
692
- self.filter_channels = filter_channels
693
- self.n_heads = n_heads
694
- self.n_layers = n_layers
695
- self.kernel_size = kernel_size
696
- self.p_dropout = p_dropout
697
- self.resblock = resblock
698
- self.resblock_kernel_sizes = resblock_kernel_sizes
699
- self.resblock_dilation_sizes = resblock_dilation_sizes
700
- self.upsample_rates = upsample_rates
701
- self.upsample_initial_channel = upsample_initial_channel
702
- self.upsample_kernel_sizes = upsample_kernel_sizes
703
- self.segment_size = segment_size
704
- self.gin_channels = gin_channels
705
- # self.hop_length = hop_length#
706
- self.spk_embed_dim=spk_embed_dim
707
- self.enc_p = TextEncoder256Sim(
708
- inter_channels,
709
- hidden_channels,
710
- filter_channels,
711
- n_heads,
712
- n_layers,
713
- kernel_size,
714
- p_dropout,
715
- )
716
- self.dec = GeneratorNSF(
717
- inter_channels,
718
- resblock,
719
- resblock_kernel_sizes,
720
- resblock_dilation_sizes,
721
- upsample_rates,
722
- upsample_initial_channel,
723
- upsample_kernel_sizes,
724
- gin_channels=gin_channels,is_half=kwargs["is_half"]
725
- )
726
-
727
- self.flow = ResidualCouplingBlock(
728
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
729
- )
730
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
731
- print("gin_channels:",gin_channels,"self.spk_embed_dim:",self.spk_embed_dim)
732
- def remove_weight_norm(self):
733
- self.dec.remove_weight_norm()
734
- self.flow.remove_weight_norm()
735
- self.enc_q.remove_weight_norm()
736
-
737
- def forward(self, phone, phone_lengths, pitch, pitchf, y_lengths,ds): # y是spec不需要了现在
738
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
739
- x, x_mask = self.enc_p(phone, pitch, phone_lengths)
740
- x = self.flow(x, x_mask, g=g, reverse=True)
741
- z_slice, ids_slice = commons.rand_slice_segments(
742
- x, y_lengths, self.segment_size
743
- )
744
-
745
- pitchf = commons.slice_segments2(
746
- pitchf, ids_slice, self.segment_size
747
- )
748
- o = self.dec(z_slice, pitchf, g=g)
749
- return o, ids_slice
750
- def infer(self, phone, phone_lengths, pitch, pitchf, ds,max_len=None): # y是spec不需要了现在
751
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
752
- x, x_mask = self.enc_p(phone, pitch, phone_lengths)
753
- x = self.flow(x, x_mask, g=g, reverse=True)
754
- o = self.dec((x*x_mask)[:, :, :max_len], pitchf, g=g)
755
- return o, o
756
-
757
- class MultiPeriodDiscriminator(torch.nn.Module):
758
- def __init__(self, use_spectral_norm=False):
759
- super(MultiPeriodDiscriminator, self).__init__()
760
- periods = [2, 3, 5, 7, 11,17]
761
- # periods = [3, 5, 7, 11, 17, 23, 37]
762
-
763
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
764
- discs = discs + [
765
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
766
- ]
767
- self.discriminators = nn.ModuleList(discs)
768
-
769
- def forward(self, y, y_hat):
770
- y_d_rs = []#
771
- y_d_gs = []
772
- fmap_rs = []
773
- fmap_gs = []
774
- for i, d in enumerate(self.discriminators):
775
- y_d_r, fmap_r = d(y)
776
- y_d_g, fmap_g = d(y_hat)
777
- # for j in range(len(fmap_r)):
778
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
779
- y_d_rs.append(y_d_r)
780
- y_d_gs.append(y_d_g)
781
- fmap_rs.append(fmap_r)
782
- fmap_gs.append(fmap_g)
783
-
784
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
785
-
786
- class DiscriminatorS(torch.nn.Module):
787
- def __init__(self, use_spectral_norm=False):
788
- super(DiscriminatorS, self).__init__()
789
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
790
- self.convs = nn.ModuleList(
791
- [
792
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
793
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
794
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
795
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
796
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
797
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
798
- ]
799
- )
800
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
801
-
802
- def forward(self, x):
803
- fmap = []
804
-
805
- for l in self.convs:
806
- x = l(x)
807
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
808
- fmap.append(x)
809
- x = self.conv_post(x)
810
- fmap.append(x)
811
- x = torch.flatten(x, 1, -1)
812
-
813
- return x, fmap
814
-
815
- class DiscriminatorP(torch.nn.Module):
816
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
817
- super(DiscriminatorP, self).__init__()
818
- self.period = period
819
- self.use_spectral_norm = use_spectral_norm
820
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
821
- self.convs = nn.ModuleList(
822
- [
823
- norm_f(
824
- Conv2d(
825
- 1,
826
- 32,
827
- (kernel_size, 1),
828
- (stride, 1),
829
- padding=(get_padding(kernel_size, 1), 0),
830
- )
831
- ),
832
- norm_f(
833
- Conv2d(
834
- 32,
835
- 128,
836
- (kernel_size, 1),
837
- (stride, 1),
838
- padding=(get_padding(kernel_size, 1), 0),
839
- )
840
- ),
841
- norm_f(
842
- Conv2d(
843
- 128,
844
- 512,
845
- (kernel_size, 1),
846
- (stride, 1),
847
- padding=(get_padding(kernel_size, 1), 0),
848
- )
849
- ),
850
- norm_f(
851
- Conv2d(
852
- 512,
853
- 1024,
854
- (kernel_size, 1),
855
- (stride, 1),
856
- padding=(get_padding(kernel_size, 1), 0),
857
- )
858
- ),
859
- norm_f(
860
- Conv2d(
861
- 1024,
862
- 1024,
863
- (kernel_size, 1),
864
- 1,
865
- padding=(get_padding(kernel_size, 1), 0),
866
- )
867
- ),
868
- ]
869
- )
870
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
871
-
872
- def forward(self, x):
873
- fmap = []
874
-
875
- # 1d to 2d
876
- b, c, t = x.shape
877
- if t % self.period != 0: # pad first
878
- n_pad = self.period - (t % self.period)
879
- x = F.pad(x, (0, n_pad), "reflect")
880
- t = t + n_pad
881
- x = x.view(b, c, t // self.period, self.period)
882
-
883
- for l in self.convs:
884
- x = l(x)
885
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
886
- fmap.append(x)
887
- x = self.conv_post(x)
888
- fmap.append(x)
889
- x = torch.flatten(x, 1, -1)
890
-
891
- return x, fmap
892
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+ class TextEncoder768(nn.Module):
63
+ def __init__(
64
+ self,
65
+ out_channels,
66
+ hidden_channels,
67
+ filter_channels,
68
+ n_heads,
69
+ n_layers,
70
+ kernel_size,
71
+ p_dropout,
72
+ f0=True,
73
+ ):
74
+ super().__init__()
75
+ self.out_channels = out_channels
76
+ self.hidden_channels = hidden_channels
77
+ self.filter_channels = filter_channels
78
+ self.n_heads = n_heads
79
+ self.n_layers = n_layers
80
+ self.kernel_size = kernel_size
81
+ self.p_dropout = p_dropout
82
+ self.emb_phone = nn.Linear(768, hidden_channels)
83
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
84
+ if f0 == True:
85
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
86
+ self.encoder = attentions.Encoder(
87
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
88
+ )
89
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
90
+
91
+ def forward(self, phone, pitch, lengths):
92
+ if pitch == None:
93
+ x = self.emb_phone(phone)
94
+ else:
95
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
96
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
97
+ x = self.lrelu(x)
98
+ x = torch.transpose(x, 1, -1) # [b, h, t]
99
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
100
+ x.dtype
101
+ )
102
+ x = self.encoder(x * x_mask, x_mask)
103
+ stats = self.proj(x) * x_mask
104
+
105
+ m, logs = torch.split(stats, self.out_channels, dim=1)
106
+ return m, logs, x_mask
107
+
108
+ class ResidualCouplingBlock(nn.Module):
109
+ def __init__(
110
+ self,
111
+ channels,
112
+ hidden_channels,
113
+ kernel_size,
114
+ dilation_rate,
115
+ n_layers,
116
+ n_flows=4,
117
+ gin_channels=0,
118
+ ):
119
+ super().__init__()
120
+ self.channels = channels
121
+ self.hidden_channels = hidden_channels
122
+ self.kernel_size = kernel_size
123
+ self.dilation_rate = dilation_rate
124
+ self.n_layers = n_layers
125
+ self.n_flows = n_flows
126
+ self.gin_channels = gin_channels
127
+
128
+ self.flows = nn.ModuleList()
129
+ for i in range(n_flows):
130
+ self.flows.append(
131
+ modules.ResidualCouplingLayer(
132
+ channels,
133
+ hidden_channels,
134
+ kernel_size,
135
+ dilation_rate,
136
+ n_layers,
137
+ gin_channels=gin_channels,
138
+ mean_only=True,
139
+ )
140
+ )
141
+ self.flows.append(modules.Flip())
142
+
143
+ def forward(self, x, x_mask, g=None, reverse=False):
144
+ if not reverse:
145
+ for flow in self.flows:
146
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
147
+ else:
148
+ for flow in reversed(self.flows):
149
+ x = flow(x, x_mask, g=g, reverse=reverse)
150
+ return x
151
+
152
+ def remove_weight_norm(self):
153
+ for i in range(self.n_flows):
154
+ self.flows[i * 2].remove_weight_norm()
155
+
156
+
157
+ class PosteriorEncoder(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels,
161
+ out_channels,
162
+ hidden_channels,
163
+ kernel_size,
164
+ dilation_rate,
165
+ n_layers,
166
+ gin_channels=0,
167
+ ):
168
+ super().__init__()
169
+ self.in_channels = in_channels
170
+ self.out_channels = out_channels
171
+ self.hidden_channels = hidden_channels
172
+ self.kernel_size = kernel_size
173
+ self.dilation_rate = dilation_rate
174
+ self.n_layers = n_layers
175
+ self.gin_channels = gin_channels
176
+
177
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
178
+ self.enc = modules.WN(
179
+ hidden_channels,
180
+ kernel_size,
181
+ dilation_rate,
182
+ n_layers,
183
+ gin_channels=gin_channels,
184
+ )
185
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
186
+
187
+ def forward(self, x, x_lengths, g=None):
188
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
189
+ x.dtype
190
+ )
191
+ x = self.pre(x) * x_mask
192
+ x = self.enc(x, x_mask, g=g)
193
+ stats = self.proj(x) * x_mask
194
+ m, logs = torch.split(stats, self.out_channels, dim=1)
195
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
196
+ return z, m, logs, x_mask
197
+
198
+ def remove_weight_norm(self):
199
+ self.enc.remove_weight_norm()
200
+
201
+
202
+ class Generator(torch.nn.Module):
203
+ def __init__(
204
+ self,
205
+ initial_channel,
206
+ resblock,
207
+ resblock_kernel_sizes,
208
+ resblock_dilation_sizes,
209
+ upsample_rates,
210
+ upsample_initial_channel,
211
+ upsample_kernel_sizes,
212
+ gin_channels=0,
213
+ ):
214
+ super(Generator, self).__init__()
215
+ self.num_kernels = len(resblock_kernel_sizes)
216
+ self.num_upsamples = len(upsample_rates)
217
+ self.conv_pre = Conv1d(
218
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
219
+ )
220
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
221
+
222
+ self.ups = nn.ModuleList()
223
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
224
+ self.ups.append(
225
+ weight_norm(
226
+ ConvTranspose1d(
227
+ upsample_initial_channel // (2**i),
228
+ upsample_initial_channel // (2 ** (i + 1)),
229
+ k,
230
+ u,
231
+ padding=(k - u) // 2,
232
+ )
233
+ )
234
+ )
235
+
236
+ self.resblocks = nn.ModuleList()
237
+ for i in range(len(self.ups)):
238
+ ch = upsample_initial_channel // (2 ** (i + 1))
239
+ for j, (k, d) in enumerate(
240
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
241
+ ):
242
+ self.resblocks.append(resblock(ch, k, d))
243
+
244
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
245
+ self.ups.apply(init_weights)
246
+
247
+ if gin_channels != 0:
248
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
249
+
250
+ def forward(self, x, g=None):
251
+ x = self.conv_pre(x)
252
+ if g is not None:
253
+ x = x + self.cond(g)
254
+
255
+ for i in range(self.num_upsamples):
256
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
257
+ x = self.ups[i](x)
258
+ xs = None
259
+ for j in range(self.num_kernels):
260
+ if xs is None:
261
+ xs = self.resblocks[i * self.num_kernels + j](x)
262
+ else:
263
+ xs += self.resblocks[i * self.num_kernels + j](x)
264
+ x = xs / self.num_kernels
265
+ x = F.leaky_relu(x)
266
+ x = self.conv_post(x)
267
+ x = torch.tanh(x)
268
+
269
+ return x
270
+
271
+ def remove_weight_norm(self):
272
+ for l in self.ups:
273
+ remove_weight_norm(l)
274
+ for l in self.resblocks:
275
+ l.remove_weight_norm()
276
+
277
+
278
+ class SineGen(torch.nn.Module):
279
+ """Definition of sine generator
280
+ SineGen(samp_rate, harmonic_num = 0,
281
+ sine_amp = 0.1, noise_std = 0.003,
282
+ voiced_threshold = 0,
283
+ flag_for_pulse=False)
284
+ samp_rate: sampling rate in Hz
285
+ harmonic_num: number of harmonic overtones (default 0)
286
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
287
+ noise_std: std of Gaussian noise (default 0.003)
288
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
289
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
290
+ Note: when flag_for_pulse is True, the first time step of a voiced
291
+ segment is always sin(np.pi) or cos(0)
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ samp_rate,
297
+ harmonic_num=0,
298
+ sine_amp=0.1,
299
+ noise_std=0.003,
300
+ voiced_threshold=0,
301
+ flag_for_pulse=False,
302
+ ):
303
+ super(SineGen, self).__init__()
304
+ self.sine_amp = sine_amp
305
+ self.noise_std = noise_std
306
+ self.harmonic_num = harmonic_num
307
+ self.dim = self.harmonic_num + 1
308
+ self.sampling_rate = samp_rate
309
+ self.voiced_threshold = voiced_threshold
310
+
311
+ def _f02uv(self, f0):
312
+ # generate uv signal
313
+ uv = torch.ones_like(f0)
314
+ uv = uv * (f0 > self.voiced_threshold)
315
+ return uv
316
+
317
+ def forward(self, f0, upp):
318
+ """sine_tensor, uv = forward(f0)
319
+ input F0: tensor(batchsize=1, length, dim=1)
320
+ f0 for unvoiced steps should be 0
321
+ output sine_tensor: tensor(batchsize=1, length, dim)
322
+ output uv: tensor(batchsize=1, length, 1)
323
+ """
324
+ with torch.no_grad():
325
+ f0 = f0[:, None].transpose(1, 2)
326
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
327
+ # fundamental component
328
+ f0_buf[:, :, 0] = f0[:, :, 0]
329
+ for idx in np.arange(self.harmonic_num):
330
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
331
+ idx + 2
332
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
333
+ rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
334
+ rand_ini = torch.rand(
335
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
336
+ )
337
+ rand_ini[:, 0] = 0
338
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
339
+ tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
340
+ tmp_over_one *= upp
341
+ tmp_over_one = F.interpolate(
342
+ tmp_over_one.transpose(2, 1),
343
+ scale_factor=upp,
344
+ mode="linear",
345
+ align_corners=True,
346
+ ).transpose(2, 1)
347
+ rad_values = F.interpolate(
348
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
349
+ ).transpose(
350
+ 2, 1
351
+ ) #######
352
+ tmp_over_one %= 1
353
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
354
+ cumsum_shift = torch.zeros_like(rad_values)
355
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
356
+ sine_waves = torch.sin(
357
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
358
+ )
359
+ sine_waves = sine_waves * self.sine_amp
360
+ uv = self._f02uv(f0)
361
+ uv = F.interpolate(
362
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
363
+ ).transpose(2, 1)
364
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
365
+ noise = noise_amp * torch.randn_like(sine_waves)
366
+ sine_waves = sine_waves * uv + noise
367
+ return sine_waves, uv, noise
368
+
369
+
370
+ class SourceModuleHnNSF(torch.nn.Module):
371
+ """SourceModule for hn-nsf
372
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
373
+ add_noise_std=0.003, voiced_threshod=0)
374
+ sampling_rate: sampling_rate in Hz
375
+ harmonic_num: number of harmonic above F0 (default: 0)
376
+ sine_amp: amplitude of sine source signal (default: 0.1)
377
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
378
+ note that amplitude of noise in unvoiced is decided
379
+ by sine_amp
380
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
381
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
382
+ F0_sampled (batchsize, length, 1)
383
+ Sine_source (batchsize, length, 1)
384
+ noise_source (batchsize, length 1)
385
+ uv (batchsize, length, 1)
386
+ """
387
+
388
+ def __init__(
389
+ self,
390
+ sampling_rate,
391
+ harmonic_num=0,
392
+ sine_amp=0.1,
393
+ add_noise_std=0.003,
394
+ voiced_threshod=0,
395
+ is_half=True,
396
+ ):
397
+ super(SourceModuleHnNSF, self).__init__()
398
+
399
+ self.sine_amp = sine_amp
400
+ self.noise_std = add_noise_std
401
+ self.is_half = is_half
402
+ # to produce sine waveforms
403
+ self.l_sin_gen = SineGen(
404
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
405
+ )
406
+
407
+ # to merge source harmonics into a single excitation
408
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
409
+ self.l_tanh = torch.nn.Tanh()
410
+
411
+ def forward(self, x, upp=None):
412
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
413
+ if self.is_half:
414
+ sine_wavs = sine_wavs.half()
415
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
416
+ return sine_merge, None, None # noise, uv
417
+
418
+
419
+ class GeneratorNSF(torch.nn.Module):
420
+ def __init__(
421
+ self,
422
+ initial_channel,
423
+ resblock,
424
+ resblock_kernel_sizes,
425
+ resblock_dilation_sizes,
426
+ upsample_rates,
427
+ upsample_initial_channel,
428
+ upsample_kernel_sizes,
429
+ gin_channels,
430
+ sr,
431
+ is_half=False,
432
+ ):
433
+ super(GeneratorNSF, self).__init__()
434
+ self.num_kernels = len(resblock_kernel_sizes)
435
+ self.num_upsamples = len(upsample_rates)
436
+
437
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
438
+ self.m_source = SourceModuleHnNSF(
439
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
440
+ )
441
+ self.noise_convs = nn.ModuleList()
442
+ self.conv_pre = Conv1d(
443
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
444
+ )
445
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
446
+
447
+ self.ups = nn.ModuleList()
448
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
449
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
450
+ self.ups.append(
451
+ weight_norm(
452
+ ConvTranspose1d(
453
+ upsample_initial_channel // (2**i),
454
+ upsample_initial_channel // (2 ** (i + 1)),
455
+ k,
456
+ u,
457
+ padding=(k - u) // 2,
458
+ )
459
+ )
460
+ )
461
+ if i + 1 < len(upsample_rates):
462
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
463
+ self.noise_convs.append(
464
+ Conv1d(
465
+ 1,
466
+ c_cur,
467
+ kernel_size=stride_f0 * 2,
468
+ stride=stride_f0,
469
+ padding=stride_f0 // 2,
470
+ )
471
+ )
472
+ else:
473
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
474
+
475
+ self.resblocks = nn.ModuleList()
476
+ for i in range(len(self.ups)):
477
+ ch = upsample_initial_channel // (2 ** (i + 1))
478
+ for j, (k, d) in enumerate(
479
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
480
+ ):
481
+ self.resblocks.append(resblock(ch, k, d))
482
+
483
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
484
+ self.ups.apply(init_weights)
485
+
486
+ if gin_channels != 0:
487
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
488
+
489
+ self.upp = np.prod(upsample_rates)
490
+
491
+ def forward(self, x, f0, g=None):
492
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
493
+ har_source = har_source.transpose(1, 2)
494
+ x = self.conv_pre(x)
495
+ if g is not None:
496
+ x = x + self.cond(g)
497
+
498
+ for i in range(self.num_upsamples):
499
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
500
+ x = self.ups[i](x)
501
+ x_source = self.noise_convs[i](har_source)
502
+ x = x + x_source
503
+ xs = None
504
+ for j in range(self.num_kernels):
505
+ if xs is None:
506
+ xs = self.resblocks[i * self.num_kernels + j](x)
507
+ else:
508
+ xs += self.resblocks[i * self.num_kernels + j](x)
509
+ x = xs / self.num_kernels
510
+ x = F.leaky_relu(x)
511
+ x = self.conv_post(x)
512
+ x = torch.tanh(x)
513
+ return x
514
+
515
+ def remove_weight_norm(self):
516
+ for l in self.ups:
517
+ remove_weight_norm(l)
518
+ for l in self.resblocks:
519
+ l.remove_weight_norm()
520
+
521
+
522
+ sr2sr = {
523
+ "32k": 32000,
524
+ "40k": 40000,
525
+ "48k": 48000,
526
+ }
527
+
528
+
529
+ class SynthesizerTrnMs256NSFsid(nn.Module):
530
+ def __init__(
531
+ self,
532
+ spec_channels,
533
+ segment_size,
534
+ inter_channels,
535
+ hidden_channels,
536
+ filter_channels,
537
+ n_heads,
538
+ n_layers,
539
+ kernel_size,
540
+ p_dropout,
541
+ resblock,
542
+ resblock_kernel_sizes,
543
+ resblock_dilation_sizes,
544
+ upsample_rates,
545
+ upsample_initial_channel,
546
+ upsample_kernel_sizes,
547
+ spk_embed_dim,
548
+ gin_channels,
549
+ sr,
550
+ **kwargs
551
+ ):
552
+ super().__init__()
553
+ if type(sr) == type("strr"):
554
+ sr = sr2sr[sr]
555
+ self.spec_channels = spec_channels
556
+ self.inter_channels = inter_channels
557
+ self.hidden_channels = hidden_channels
558
+ self.filter_channels = filter_channels
559
+ self.n_heads = n_heads
560
+ self.n_layers = n_layers
561
+ self.kernel_size = kernel_size
562
+ self.p_dropout = p_dropout
563
+ self.resblock = resblock
564
+ self.resblock_kernel_sizes = resblock_kernel_sizes
565
+ self.resblock_dilation_sizes = resblock_dilation_sizes
566
+ self.upsample_rates = upsample_rates
567
+ self.upsample_initial_channel = upsample_initial_channel
568
+ self.upsample_kernel_sizes = upsample_kernel_sizes
569
+ self.segment_size = segment_size
570
+ self.gin_channels = gin_channels
571
+ # self.hop_length = hop_length#
572
+ self.spk_embed_dim = spk_embed_dim
573
+ self.enc_p = TextEncoder256(
574
+ inter_channels,
575
+ hidden_channels,
576
+ filter_channels,
577
+ n_heads,
578
+ n_layers,
579
+ kernel_size,
580
+ p_dropout,
581
+ )
582
+ self.dec = GeneratorNSF(
583
+ inter_channels,
584
+ resblock,
585
+ resblock_kernel_sizes,
586
+ resblock_dilation_sizes,
587
+ upsample_rates,
588
+ upsample_initial_channel,
589
+ upsample_kernel_sizes,
590
+ gin_channels=gin_channels,
591
+ sr=sr,
592
+ is_half=kwargs["is_half"],
593
+ )
594
+ self.enc_q = PosteriorEncoder(
595
+ spec_channels,
596
+ inter_channels,
597
+ hidden_channels,
598
+ 5,
599
+ 1,
600
+ 16,
601
+ gin_channels=gin_channels,
602
+ )
603
+ self.flow = ResidualCouplingBlock(
604
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
605
+ )
606
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
607
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
608
+
609
+ def remove_weight_norm(self):
610
+ self.dec.remove_weight_norm()
611
+ self.flow.remove_weight_norm()
612
+ self.enc_q.remove_weight_norm()
613
+
614
+ def forward(
615
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
616
+ ): # 这里ds是id,[bs,1]
617
+ # print(1,pitch.shape)#[bs,t]
618
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
619
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
620
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
621
+ z_p = self.flow(z, y_mask, g=g)
622
+ z_slice, ids_slice = commons.rand_slice_segments(
623
+ z, y_lengths, self.segment_size
624
+ )
625
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
626
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
627
+ # print(-2,pitchf.shape,z_slice.shape)
628
+ o = self.dec(z_slice, pitchf, g=g)
629
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
630
+
631
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
632
+ g = self.emb_g(sid).unsqueeze(-1)
633
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
634
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
635
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
636
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
637
+ return o, x_mask, (z, z_p, m_p, logs_p)
638
+ class SynthesizerTrnMs768NSFsid(nn.Module):
639
+ def __init__(
640
+ self,
641
+ spec_channels,
642
+ segment_size,
643
+ inter_channels,
644
+ hidden_channels,
645
+ filter_channels,
646
+ n_heads,
647
+ n_layers,
648
+ kernel_size,
649
+ p_dropout,
650
+ resblock,
651
+ resblock_kernel_sizes,
652
+ resblock_dilation_sizes,
653
+ upsample_rates,
654
+ upsample_initial_channel,
655
+ upsample_kernel_sizes,
656
+ spk_embed_dim,
657
+ gin_channels,
658
+ sr,
659
+ **kwargs
660
+ ):
661
+ super().__init__()
662
+ if type(sr) == type("strr"):
663
+ sr = sr2sr[sr]
664
+ self.spec_channels = spec_channels
665
+ self.inter_channels = inter_channels
666
+ self.hidden_channels = hidden_channels
667
+ self.filter_channels = filter_channels
668
+ self.n_heads = n_heads
669
+ self.n_layers = n_layers
670
+ self.kernel_size = kernel_size
671
+ self.p_dropout = p_dropout
672
+ self.resblock = resblock
673
+ self.resblock_kernel_sizes = resblock_kernel_sizes
674
+ self.resblock_dilation_sizes = resblock_dilation_sizes
675
+ self.upsample_rates = upsample_rates
676
+ self.upsample_initial_channel = upsample_initial_channel
677
+ self.upsample_kernel_sizes = upsample_kernel_sizes
678
+ self.segment_size = segment_size
679
+ self.gin_channels = gin_channels
680
+ # self.hop_length = hop_length#
681
+ self.spk_embed_dim = spk_embed_dim
682
+ self.enc_p = TextEncoder768(
683
+ inter_channels,
684
+ hidden_channels,
685
+ filter_channels,
686
+ n_heads,
687
+ n_layers,
688
+ kernel_size,
689
+ p_dropout,
690
+ )
691
+ self.dec = GeneratorNSF(
692
+ inter_channels,
693
+ resblock,
694
+ resblock_kernel_sizes,
695
+ resblock_dilation_sizes,
696
+ upsample_rates,
697
+ upsample_initial_channel,
698
+ upsample_kernel_sizes,
699
+ gin_channels=gin_channels,
700
+ sr=sr,
701
+ is_half=kwargs["is_half"],
702
+ )
703
+ self.enc_q = PosteriorEncoder(
704
+ spec_channels,
705
+ inter_channels,
706
+ hidden_channels,
707
+ 5,
708
+ 1,
709
+ 16,
710
+ gin_channels=gin_channels,
711
+ )
712
+ self.flow = ResidualCouplingBlock(
713
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
714
+ )
715
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
716
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
717
+
718
+ def remove_weight_norm(self):
719
+ self.dec.remove_weight_norm()
720
+ self.flow.remove_weight_norm()
721
+ self.enc_q.remove_weight_norm()
722
+
723
+ def forward(
724
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
725
+ ): # 这里ds是id,[bs,1]
726
+ # print(1,pitch.shape)#[bs,t]
727
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
728
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
729
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
730
+ z_p = self.flow(z, y_mask, g=g)
731
+ z_slice, ids_slice = commons.rand_slice_segments(
732
+ z, y_lengths, self.segment_size
733
+ )
734
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
735
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
736
+ # print(-2,pitchf.shape,z_slice.shape)
737
+ o = self.dec(z_slice, pitchf, g=g)
738
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
739
+
740
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
741
+ g = self.emb_g(sid).unsqueeze(-1)
742
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
743
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
744
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
745
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
746
+ return o, x_mask, (z, z_p, m_p, logs_p)
747
+
748
+
749
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
750
+ def __init__(
751
+ self,
752
+ spec_channels,
753
+ segment_size,
754
+ inter_channels,
755
+ hidden_channels,
756
+ filter_channels,
757
+ n_heads,
758
+ n_layers,
759
+ kernel_size,
760
+ p_dropout,
761
+ resblock,
762
+ resblock_kernel_sizes,
763
+ resblock_dilation_sizes,
764
+ upsample_rates,
765
+ upsample_initial_channel,
766
+ upsample_kernel_sizes,
767
+ spk_embed_dim,
768
+ gin_channels,
769
+ sr=None,
770
+ **kwargs
771
+ ):
772
+ super().__init__()
773
+ self.spec_channels = spec_channels
774
+ self.inter_channels = inter_channels
775
+ self.hidden_channels = hidden_channels
776
+ self.filter_channels = filter_channels
777
+ self.n_heads = n_heads
778
+ self.n_layers = n_layers
779
+ self.kernel_size = kernel_size
780
+ self.p_dropout = p_dropout
781
+ self.resblock = resblock
782
+ self.resblock_kernel_sizes = resblock_kernel_sizes
783
+ self.resblock_dilation_sizes = resblock_dilation_sizes
784
+ self.upsample_rates = upsample_rates
785
+ self.upsample_initial_channel = upsample_initial_channel
786
+ self.upsample_kernel_sizes = upsample_kernel_sizes
787
+ self.segment_size = segment_size
788
+ self.gin_channels = gin_channels
789
+ # self.hop_length = hop_length#
790
+ self.spk_embed_dim = spk_embed_dim
791
+ self.enc_p = TextEncoder256(
792
+ inter_channels,
793
+ hidden_channels,
794
+ filter_channels,
795
+ n_heads,
796
+ n_layers,
797
+ kernel_size,
798
+ p_dropout,
799
+ f0=False,
800
+ )
801
+ self.dec = Generator(
802
+ inter_channels,
803
+ resblock,
804
+ resblock_kernel_sizes,
805
+ resblock_dilation_sizes,
806
+ upsample_rates,
807
+ upsample_initial_channel,
808
+ upsample_kernel_sizes,
809
+ gin_channels=gin_channels,
810
+ )
811
+ self.enc_q = PosteriorEncoder(
812
+ spec_channels,
813
+ inter_channels,
814
+ hidden_channels,
815
+ 5,
816
+ 1,
817
+ 16,
818
+ gin_channels=gin_channels,
819
+ )
820
+ self.flow = ResidualCouplingBlock(
821
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
822
+ )
823
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
824
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
825
+
826
+ def remove_weight_norm(self):
827
+ self.dec.remove_weight_norm()
828
+ self.flow.remove_weight_norm()
829
+ self.enc_q.remove_weight_norm()
830
+
831
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
832
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
833
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
834
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
835
+ z_p = self.flow(z, y_mask, g=g)
836
+ z_slice, ids_slice = commons.rand_slice_segments(
837
+ z, y_lengths, self.segment_size
838
+ )
839
+ o = self.dec(z_slice, g=g)
840
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
841
+
842
+ def infer(self, phone, phone_lengths, sid, max_len=None):
843
+ g = self.emb_g(sid).unsqueeze(-1)
844
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
845
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
846
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
847
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
848
+ return o, x_mask, (z, z_p, m_p, logs_p)
849
+ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
850
+ def __init__(
851
+ self,
852
+ spec_channels,
853
+ segment_size,
854
+ inter_channels,
855
+ hidden_channels,
856
+ filter_channels,
857
+ n_heads,
858
+ n_layers,
859
+ kernel_size,
860
+ p_dropout,
861
+ resblock,
862
+ resblock_kernel_sizes,
863
+ resblock_dilation_sizes,
864
+ upsample_rates,
865
+ upsample_initial_channel,
866
+ upsample_kernel_sizes,
867
+ spk_embed_dim,
868
+ gin_channels,
869
+ sr=None,
870
+ **kwargs
871
+ ):
872
+ super().__init__()
873
+ self.spec_channels = spec_channels
874
+ self.inter_channels = inter_channels
875
+ self.hidden_channels = hidden_channels
876
+ self.filter_channels = filter_channels
877
+ self.n_heads = n_heads
878
+ self.n_layers = n_layers
879
+ self.kernel_size = kernel_size
880
+ self.p_dropout = p_dropout
881
+ self.resblock = resblock
882
+ self.resblock_kernel_sizes = resblock_kernel_sizes
883
+ self.resblock_dilation_sizes = resblock_dilation_sizes
884
+ self.upsample_rates = upsample_rates
885
+ self.upsample_initial_channel = upsample_initial_channel
886
+ self.upsample_kernel_sizes = upsample_kernel_sizes
887
+ self.segment_size = segment_size
888
+ self.gin_channels = gin_channels
889
+ # self.hop_length = hop_length#
890
+ self.spk_embed_dim = spk_embed_dim
891
+ self.enc_p = TextEncoder768(
892
+ inter_channels,
893
+ hidden_channels,
894
+ filter_channels,
895
+ n_heads,
896
+ n_layers,
897
+ kernel_size,
898
+ p_dropout,
899
+ f0=False,
900
+ )
901
+ self.dec = Generator(
902
+ inter_channels,
903
+ resblock,
904
+ resblock_kernel_sizes,
905
+ resblock_dilation_sizes,
906
+ upsample_rates,
907
+ upsample_initial_channel,
908
+ upsample_kernel_sizes,
909
+ gin_channels=gin_channels,
910
+ )
911
+ self.enc_q = PosteriorEncoder(
912
+ spec_channels,
913
+ inter_channels,
914
+ hidden_channels,
915
+ 5,
916
+ 1,
917
+ 16,
918
+ gin_channels=gin_channels,
919
+ )
920
+ self.flow = ResidualCouplingBlock(
921
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
922
+ )
923
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
924
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
925
+
926
+ def remove_weight_norm(self):
927
+ self.dec.remove_weight_norm()
928
+ self.flow.remove_weight_norm()
929
+ self.enc_q.remove_weight_norm()
930
+
931
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
932
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
933
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
934
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
935
+ z_p = self.flow(z, y_mask, g=g)
936
+ z_slice, ids_slice = commons.rand_slice_segments(
937
+ z, y_lengths, self.segment_size
938
+ )
939
+ o = self.dec(z_slice, g=g)
940
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
941
+
942
+ def infer(self, phone, phone_lengths, sid, max_len=None):
943
+ g = self.emb_g(sid).unsqueeze(-1)
944
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
945
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
946
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
947
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
948
+ return o, x_mask, (z, z_p, m_p, logs_p)
949
+
950
+
951
+ class MultiPeriodDiscriminator(torch.nn.Module):
952
+ def __init__(self, use_spectral_norm=False):
953
+ super(MultiPeriodDiscriminator, self).__init__()
954
+ periods = [2, 3, 5, 7, 11, 17]
955
+ # periods = [3, 5, 7, 11, 17, 23, 37]
956
+
957
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
958
+ discs = discs + [
959
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
960
+ ]
961
+ self.discriminators = nn.ModuleList(discs)
962
+
963
+ def forward(self, y, y_hat):
964
+ y_d_rs = [] #
965
+ y_d_gs = []
966
+ fmap_rs = []
967
+ fmap_gs = []
968
+ for i, d in enumerate(self.discriminators):
969
+ y_d_r, fmap_r = d(y)
970
+ y_d_g, fmap_g = d(y_hat)
971
+ # for j in range(len(fmap_r)):
972
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
973
+ y_d_rs.append(y_d_r)
974
+ y_d_gs.append(y_d_g)
975
+ fmap_rs.append(fmap_r)
976
+ fmap_gs.append(fmap_g)
977
+
978
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
979
+
980
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
981
+ def __init__(self, use_spectral_norm=False):
982
+ super(MultiPeriodDiscriminatorV2, self).__init__()
983
+ # periods = [2, 3, 5, 7, 11, 17]
984
+ periods = [2,3, 5, 7, 11, 17, 23, 37]
985
+
986
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
987
+ discs = discs + [
988
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
989
+ ]
990
+ self.discriminators = nn.ModuleList(discs)
991
+
992
+ def forward(self, y, y_hat):
993
+ y_d_rs = [] #
994
+ y_d_gs = []
995
+ fmap_rs = []
996
+ fmap_gs = []
997
+ for i, d in enumerate(self.discriminators):
998
+ y_d_r, fmap_r = d(y)
999
+ y_d_g, fmap_g = d(y_hat)
1000
+ # for j in range(len(fmap_r)):
1001
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1002
+ y_d_rs.append(y_d_r)
1003
+ y_d_gs.append(y_d_g)
1004
+ fmap_rs.append(fmap_r)
1005
+ fmap_gs.append(fmap_g)
1006
+
1007
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1008
+
1009
+
1010
+ class DiscriminatorS(torch.nn.Module):
1011
+ def __init__(self, use_spectral_norm=False):
1012
+ super(DiscriminatorS, self).__init__()
1013
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1014
+ self.convs = nn.ModuleList(
1015
+ [
1016
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1017
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1018
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1019
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1020
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1021
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1022
+ ]
1023
+ )
1024
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1025
+
1026
+ def forward(self, x):
1027
+ fmap = []
1028
+
1029
+ for l in self.convs:
1030
+ x = l(x)
1031
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1032
+ fmap.append(x)
1033
+ x = self.conv_post(x)
1034
+ fmap.append(x)
1035
+ x = torch.flatten(x, 1, -1)
1036
+
1037
+ return x, fmap
1038
+
1039
+
1040
+ class DiscriminatorP(torch.nn.Module):
1041
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1042
+ super(DiscriminatorP, self).__init__()
1043
+ self.period = period
1044
+ self.use_spectral_norm = use_spectral_norm
1045
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1046
+ self.convs = nn.ModuleList(
1047
+ [
1048
+ norm_f(
1049
+ Conv2d(
1050
+ 1,
1051
+ 32,
1052
+ (kernel_size, 1),
1053
+ (stride, 1),
1054
+ padding=(get_padding(kernel_size, 1), 0),
1055
+ )
1056
+ ),
1057
+ norm_f(
1058
+ Conv2d(
1059
+ 32,
1060
+ 128,
1061
+ (kernel_size, 1),
1062
+ (stride, 1),
1063
+ padding=(get_padding(kernel_size, 1), 0),
1064
+ )
1065
+ ),
1066
+ norm_f(
1067
+ Conv2d(
1068
+ 128,
1069
+ 512,
1070
+ (kernel_size, 1),
1071
+ (stride, 1),
1072
+ padding=(get_padding(kernel_size, 1), 0),
1073
+ )
1074
+ ),
1075
+ norm_f(
1076
+ Conv2d(
1077
+ 512,
1078
+ 1024,
1079
+ (kernel_size, 1),
1080
+ (stride, 1),
1081
+ padding=(get_padding(kernel_size, 1), 0),
1082
+ )
1083
+ ),
1084
+ norm_f(
1085
+ Conv2d(
1086
+ 1024,
1087
+ 1024,
1088
+ (kernel_size, 1),
1089
+ 1,
1090
+ padding=(get_padding(kernel_size, 1), 0),
1091
+ )
1092
+ ),
1093
+ ]
1094
+ )
1095
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1096
+
1097
+ def forward(self, x):
1098
+ fmap = []
1099
+
1100
+ # 1d to 2d
1101
+ b, c, t = x.shape
1102
+ if t % self.period != 0: # pad first
1103
+ n_pad = self.period - (t % self.period)
1104
+ x = F.pad(x, (0, n_pad), "reflect")
1105
+ t = t + n_pad
1106
+ x = x.view(b, c, t // self.period, self.period)
1107
+
1108
+ for l in self.convs:
1109
+ x = l(x)
1110
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1111
+ fmap.append(x)
1112
+ x = self.conv_post(x)
1113
+ fmap.append(x)
1114
+ x = torch.flatten(x, 1, -1)
1115
+
1116
+ return x, fmap
infer_pack/models_onnx.py CHANGED
@@ -1,764 +1,760 @@
1
- import math,pdb,os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
- class TextEncoder256(nn.Module):
16
- def __init__(
17
- self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True ):
18
- super().__init__()
19
- self.out_channels = out_channels
20
- self.hidden_channels = hidden_channels
21
- self.filter_channels = filter_channels
22
- self.n_heads = n_heads
23
- self.n_layers = n_layers
24
- self.kernel_size = kernel_size
25
- self.p_dropout = p_dropout
26
- self.emb_phone = nn.Linear(256, hidden_channels)
27
- self.lrelu=nn.LeakyReLU(0.1,inplace=True)
28
- if(f0==True):
29
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
30
- self.encoder = attentions.Encoder(
31
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
32
- )
33
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
34
-
35
- def forward(self, phone, pitch, lengths):
36
- if(pitch==None):
37
- x = self.emb_phone(phone)
38
- else:
39
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
40
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
41
- x=self.lrelu(x)
42
- x = torch.transpose(x, 1, -1) # [b, h, t]
43
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
44
- x.dtype
45
- )
46
- x = self.encoder(x * x_mask, x_mask)
47
- stats = self.proj(x) * x_mask
48
-
49
- m, logs = torch.split(stats, self.out_channels, dim=1)
50
- return m, logs, x_mask
51
- class TextEncoder256Sim(nn.Module):
52
- def __init__( self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True):
53
- super().__init__()
54
- self.out_channels = out_channels
55
- self.hidden_channels = hidden_channels
56
- self.filter_channels = filter_channels
57
- self.n_heads = n_heads
58
- self.n_layers = n_layers
59
- self.kernel_size = kernel_size
60
- self.p_dropout = p_dropout
61
- self.emb_phone = nn.Linear(256, hidden_channels)
62
- self.lrelu=nn.LeakyReLU(0.1,inplace=True)
63
- if(f0==True):
64
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
65
- self.encoder = attentions.Encoder(
66
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
67
- )
68
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
69
-
70
- def forward(self, phone, pitch, lengths):
71
- if(pitch==None):
72
- x = self.emb_phone(phone)
73
- else:
74
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
75
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
76
- x=self.lrelu(x)
77
- x = torch.transpose(x, 1, -1) # [b, h, t]
78
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
79
- x = self.encoder(x * x_mask, x_mask)
80
- x = self.proj(x) * x_mask
81
- return x,x_mask
82
- class ResidualCouplingBlock(nn.Module):
83
- def __init__(
84
- self,
85
- channels,
86
- hidden_channels,
87
- kernel_size,
88
- dilation_rate,
89
- n_layers,
90
- n_flows=4,
91
- gin_channels=0,
92
- ):
93
- super().__init__()
94
- self.channels = channels
95
- self.hidden_channels = hidden_channels
96
- self.kernel_size = kernel_size
97
- self.dilation_rate = dilation_rate
98
- self.n_layers = n_layers
99
- self.n_flows = n_flows
100
- self.gin_channels = gin_channels
101
-
102
- self.flows = nn.ModuleList()
103
- for i in range(n_flows):
104
- self.flows.append(
105
- modules.ResidualCouplingLayer(
106
- channels,
107
- hidden_channels,
108
- kernel_size,
109
- dilation_rate,
110
- n_layers,
111
- gin_channels=gin_channels,
112
- mean_only=True,
113
- )
114
- )
115
- self.flows.append(modules.Flip())
116
-
117
- def forward(self, x, x_mask, g=None, reverse=False):
118
- if not reverse:
119
- for flow in self.flows:
120
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
121
- else:
122
- for flow in reversed(self.flows):
123
- x = flow(x, x_mask, g=g, reverse=reverse)
124
- return x
125
-
126
- def remove_weight_norm(self):
127
- for i in range(self.n_flows):
128
- self.flows[i * 2].remove_weight_norm()
129
- class PosteriorEncoder(nn.Module):
130
- def __init__(
131
- self,
132
- in_channels,
133
- out_channels,
134
- hidden_channels,
135
- kernel_size,
136
- dilation_rate,
137
- n_layers,
138
- gin_channels=0,
139
- ):
140
- super().__init__()
141
- self.in_channels = in_channels
142
- self.out_channels = out_channels
143
- self.hidden_channels = hidden_channels
144
- self.kernel_size = kernel_size
145
- self.dilation_rate = dilation_rate
146
- self.n_layers = n_layers
147
- self.gin_channels = gin_channels
148
-
149
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
150
- self.enc = modules.WN(
151
- hidden_channels,
152
- kernel_size,
153
- dilation_rate,
154
- n_layers,
155
- gin_channels=gin_channels,
156
- )
157
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
158
-
159
- def forward(self, x, x_lengths, g=None):
160
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
161
- x.dtype
162
- )
163
- x = self.pre(x) * x_mask
164
- x = self.enc(x, x_mask, g=g)
165
- stats = self.proj(x) * x_mask
166
- m, logs = torch.split(stats, self.out_channels, dim=1)
167
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
168
- return z, m, logs, x_mask
169
-
170
- def remove_weight_norm(self):
171
- self.enc.remove_weight_norm()
172
- class Generator(torch.nn.Module):
173
- def __init__(
174
- self,
175
- initial_channel,
176
- resblock,
177
- resblock_kernel_sizes,
178
- resblock_dilation_sizes,
179
- upsample_rates,
180
- upsample_initial_channel,
181
- upsample_kernel_sizes,
182
- gin_channels=0,
183
- ):
184
- super(Generator, self).__init__()
185
- self.num_kernels = len(resblock_kernel_sizes)
186
- self.num_upsamples = len(upsample_rates)
187
- self.conv_pre = Conv1d(
188
- initial_channel, upsample_initial_channel, 7, 1, padding=3
189
- )
190
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
191
-
192
- self.ups = nn.ModuleList()
193
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
194
- self.ups.append(
195
- weight_norm(
196
- ConvTranspose1d(
197
- upsample_initial_channel // (2**i),
198
- upsample_initial_channel // (2 ** (i + 1)),
199
- k,
200
- u,
201
- padding=(k - u) // 2,
202
- )
203
- )
204
- )
205
-
206
- self.resblocks = nn.ModuleList()
207
- for i in range(len(self.ups)):
208
- ch = upsample_initial_channel // (2 ** (i + 1))
209
- for j, (k, d) in enumerate(
210
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
211
- ):
212
- self.resblocks.append(resblock(ch, k, d))
213
-
214
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
215
- self.ups.apply(init_weights)
216
-
217
- if gin_channels != 0:
218
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
219
-
220
- def forward(self, x, g=None):
221
- x = self.conv_pre(x)
222
- if g is not None:
223
- x = x + self.cond(g)
224
-
225
- for i in range(self.num_upsamples):
226
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
227
- x = self.ups[i](x)
228
- xs = None
229
- for j in range(self.num_kernels):
230
- if xs is None:
231
- xs = self.resblocks[i * self.num_kernels + j](x)
232
- else:
233
- xs += self.resblocks[i * self.num_kernels + j](x)
234
- x = xs / self.num_kernels
235
- x = F.leaky_relu(x)
236
- x = self.conv_post(x)
237
- x = torch.tanh(x)
238
-
239
- return x
240
-
241
- def remove_weight_norm(self):
242
- for l in self.ups:
243
- remove_weight_norm(l)
244
- for l in self.resblocks:
245
- l.remove_weight_norm()
246
- class SineGen(torch.nn.Module):
247
- """ Definition of sine generator
248
- SineGen(samp_rate, harmonic_num = 0,
249
- sine_amp = 0.1, noise_std = 0.003,
250
- voiced_threshold = 0,
251
- flag_for_pulse=False)
252
- samp_rate: sampling rate in Hz
253
- harmonic_num: number of harmonic overtones (default 0)
254
- sine_amp: amplitude of sine-wavefrom (default 0.1)
255
- noise_std: std of Gaussian noise (default 0.003)
256
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
257
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
258
- Note: when flag_for_pulse is True, the first time step of a voiced
259
- segment is always sin(np.pi) or cos(0)
260
- """
261
-
262
- def __init__(self, samp_rate, harmonic_num=0,
263
- sine_amp=0.1, noise_std=0.003,
264
- voiced_threshold=0,
265
- flag_for_pulse=False):
266
- super(SineGen, self).__init__()
267
- self.sine_amp = sine_amp
268
- self.noise_std = noise_std
269
- self.harmonic_num = harmonic_num
270
- self.dim = self.harmonic_num + 1
271
- self.sampling_rate = samp_rate
272
- self.voiced_threshold = voiced_threshold
273
-
274
- def _f02uv(self, f0):
275
- # generate uv signal
276
- uv = torch.ones_like(f0)
277
- uv = uv * (f0 > self.voiced_threshold)
278
- return uv
279
-
280
- def forward(self, f0,upp):
281
- """ sine_tensor, uv = forward(f0)
282
- input F0: tensor(batchsize=1, length, dim=1)
283
- f0 for unvoiced steps should be 0
284
- output sine_tensor: tensor(batchsize=1, length, dim)
285
- output uv: tensor(batchsize=1, length, 1)
286
- """
287
- with torch.no_grad():
288
- f0 = f0[:, None].transpose(1, 2)
289
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,device=f0.device)
290
- # fundamental component
291
- f0_buf[:, :, 0] = f0[:, :, 0]
292
- for idx in np.arange(self.harmonic_num):f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)# idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
293
- rad_values = (f0_buf / self.sampling_rate) % 1###%1意味着n_har的乘积无法后处理优化
294
- rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
295
- rand_ini[:, 0] = 0
296
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
297
- tmp_over_one = torch.cumsum(rad_values, 1)# % 1 #####%1意味着后面的cumsum无法再优化
298
- tmp_over_one*=upp
299
- tmp_over_one=F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=upp, mode='linear', align_corners=True).transpose(2, 1)
300
- rad_values=F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)#######
301
- tmp_over_one%=1
302
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
303
- cumsum_shift = torch.zeros_like(rad_values)
304
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
305
- sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
306
- sine_waves = sine_waves * self.sine_amp
307
- uv = self._f02uv(f0)
308
- uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
309
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
310
- noise = noise_amp * torch.randn_like(sine_waves)
311
- sine_waves = sine_waves * uv + noise
312
- return sine_waves, uv, noise
313
- class SourceModuleHnNSF(torch.nn.Module):
314
- """ SourceModule for hn-nsf
315
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
316
- add_noise_std=0.003, voiced_threshod=0)
317
- sampling_rate: sampling_rate in Hz
318
- harmonic_num: number of harmonic above F0 (default: 0)
319
- sine_amp: amplitude of sine source signal (default: 0.1)
320
- add_noise_std: std of additive Gaussian noise (default: 0.003)
321
- note that amplitude of noise in unvoiced is decided
322
- by sine_amp
323
- voiced_threshold: threhold to set U/V given F0 (default: 0)
324
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
325
- F0_sampled (batchsize, length, 1)
326
- Sine_source (batchsize, length, 1)
327
- noise_source (batchsize, length 1)
328
- uv (batchsize, length, 1)
329
- """
330
-
331
- def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
332
- add_noise_std=0.003, voiced_threshod=0,is_half=True):
333
- super(SourceModuleHnNSF, self).__init__()
334
-
335
- self.sine_amp = sine_amp
336
- self.noise_std = add_noise_std
337
- self.is_half=is_half
338
- # to produce sine waveforms
339
- self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
340
- sine_amp, add_noise_std, voiced_threshod)
341
-
342
- # to merge source harmonics into a single excitation
343
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
344
- self.l_tanh = torch.nn.Tanh()
345
-
346
- def forward(self, x,upp=None):
347
- sine_wavs, uv, _ = self.l_sin_gen(x,upp)
348
- if(self.is_half):sine_wavs=sine_wavs.half()
349
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
350
- return sine_merge,None,None# noise, uv
351
- class GeneratorNSF(torch.nn.Module):
352
- def __init__(
353
- self,
354
- initial_channel,
355
- resblock,
356
- resblock_kernel_sizes,
357
- resblock_dilation_sizes,
358
- upsample_rates,
359
- upsample_initial_channel,
360
- upsample_kernel_sizes,
361
- gin_channels,
362
- sr,
363
- is_half=False
364
- ):
365
- super(GeneratorNSF, self).__init__()
366
- self.num_kernels = len(resblock_kernel_sizes)
367
- self.num_upsamples = len(upsample_rates)
368
-
369
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
370
- self.m_source = SourceModuleHnNSF(
371
- sampling_rate=sr,
372
- harmonic_num=0,
373
- is_half=is_half
374
- )
375
- self.noise_convs = nn.ModuleList()
376
- self.conv_pre = Conv1d(
377
- initial_channel, upsample_initial_channel, 7, 1, padding=3
378
- )
379
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
380
-
381
- self.ups = nn.ModuleList()
382
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
383
- c_cur = upsample_initial_channel // (2 ** (i + 1))
384
- self.ups.append(
385
- weight_norm(
386
- ConvTranspose1d(
387
- upsample_initial_channel // (2**i),
388
- upsample_initial_channel // (2 ** (i + 1)),
389
- k,
390
- u,
391
- padding=(k - u) // 2,
392
- )
393
- )
394
- )
395
- if i + 1 < len(upsample_rates):
396
- stride_f0 = np.prod(upsample_rates[i + 1:])
397
- self.noise_convs.append(Conv1d(
398
- 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
399
- else:
400
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
401
-
402
- self.resblocks = nn.ModuleList()
403
- for i in range(len(self.ups)):
404
- ch = upsample_initial_channel // (2 ** (i + 1))
405
- for j, (k, d) in enumerate(
406
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
407
- ):
408
- self.resblocks.append(resblock(ch, k, d))
409
-
410
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
411
- self.ups.apply(init_weights)
412
-
413
- if gin_channels != 0:
414
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
415
-
416
- self.upp=np.prod(upsample_rates)
417
-
418
- def forward(self, x, f0,g=None):
419
- har_source, noi_source, uv = self.m_source(f0,self.upp)
420
- har_source = har_source.transpose(1, 2)
421
- x = self.conv_pre(x)
422
- if g is not None:
423
- x = x + self.cond(g)
424
-
425
- for i in range(self.num_upsamples):
426
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
427
- x = self.ups[i](x)
428
- x_source = self.noise_convs[i](har_source)
429
- x = x + x_source
430
- xs = None
431
- for j in range(self.num_kernels):
432
- if xs is None:
433
- xs = self.resblocks[i * self.num_kernels + j](x)
434
- else:
435
- xs += self.resblocks[i * self.num_kernels + j](x)
436
- x = xs / self.num_kernels
437
- x = F.leaky_relu(x)
438
- x = self.conv_post(x)
439
- x = torch.tanh(x)
440
- return x
441
-
442
- def remove_weight_norm(self):
443
- for l in self.ups:
444
- remove_weight_norm(l)
445
- for l in self.resblocks:
446
- l.remove_weight_norm()
447
- sr2sr={
448
- "32k":32000,
449
- "40k":40000,
450
- "48k":48000,
451
- }
452
- class SynthesizerTrnMs256NSFsid(nn.Module):
453
- def __init__(
454
- self,
455
- spec_channels,
456
- segment_size,
457
- inter_channels,
458
- hidden_channels,
459
- filter_channels,
460
- n_heads,
461
- n_layers,
462
- kernel_size,
463
- p_dropout,
464
- resblock,
465
- resblock_kernel_sizes,
466
- resblock_dilation_sizes,
467
- upsample_rates,
468
- upsample_initial_channel,
469
- upsample_kernel_sizes,
470
- spk_embed_dim,
471
- gin_channels,
472
- sr,
473
- **kwargs
474
- ):
475
-
476
- super().__init__()
477
- if(type(sr)==type("strr")):
478
- sr=sr2sr[sr]
479
- self.spec_channels = spec_channels
480
- self.inter_channels = inter_channels
481
- self.hidden_channels = hidden_channels
482
- self.filter_channels = filter_channels
483
- self.n_heads = n_heads
484
- self.n_layers = n_layers
485
- self.kernel_size = kernel_size
486
- self.p_dropout = p_dropout
487
- self.resblock = resblock
488
- self.resblock_kernel_sizes = resblock_kernel_sizes
489
- self.resblock_dilation_sizes = resblock_dilation_sizes
490
- self.upsample_rates = upsample_rates
491
- self.upsample_initial_channel = upsample_initial_channel
492
- self.upsample_kernel_sizes = upsample_kernel_sizes
493
- self.segment_size = segment_size
494
- self.gin_channels = gin_channels
495
- # self.hop_length = hop_length#
496
- self.spk_embed_dim=spk_embed_dim
497
- self.enc_p = TextEncoder256(
498
- inter_channels,
499
- hidden_channels,
500
- filter_channels,
501
- n_heads,
502
- n_layers,
503
- kernel_size,
504
- p_dropout,
505
- )
506
- self.dec = GeneratorNSF(
507
- inter_channels,
508
- resblock,
509
- resblock_kernel_sizes,
510
- resblock_dilation_sizes,
511
- upsample_rates,
512
- upsample_initial_channel,
513
- upsample_kernel_sizes,
514
- gin_channels=gin_channels, sr=sr, is_half=kwargs["is_half"]
515
- )
516
- self.enc_q = PosteriorEncoder(
517
- spec_channels,
518
- inter_channels,
519
- hidden_channels,
520
- 5,
521
- 1,
522
- 16,
523
- gin_channels=gin_channels,
524
- )
525
- self.flow = ResidualCouplingBlock(
526
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
527
- )
528
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
529
- print("gin_channels:",gin_channels,"self.spk_embed_dim:",self.spk_embed_dim)
530
- def remove_weight_norm(self):
531
- self.dec.remove_weight_norm()
532
- self.flow.remove_weight_norm()
533
- self.enc_q.remove_weight_norm()
534
-
535
- def forward(self, phone, phone_lengths, pitch, nsff0 ,sid, rnd, max_len=None):
536
-
537
- g = self.emb_g(sid).unsqueeze(-1)
538
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
539
- z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
540
- z = self.flow(z_p, x_mask, g=g, reverse=True)
541
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0,g=g)
542
- return o
543
-
544
- class SynthesizerTrnMs256NSFsid_sim(nn.Module):
545
- """
546
- Synthesizer for Training
547
- """
548
-
549
- def __init__(
550
- self,
551
- spec_channels,
552
- segment_size,
553
- inter_channels,
554
- hidden_channels,
555
- filter_channels,
556
- n_heads,
557
- n_layers,
558
- kernel_size,
559
- p_dropout,
560
- resblock,
561
- resblock_kernel_sizes,
562
- resblock_dilation_sizes,
563
- upsample_rates,
564
- upsample_initial_channel,
565
- upsample_kernel_sizes,
566
- spk_embed_dim,
567
- # hop_length,
568
- gin_channels=0,
569
- use_sdp=True,
570
- **kwargs
571
- ):
572
-
573
- super().__init__()
574
- self.spec_channels = spec_channels
575
- self.inter_channels = inter_channels
576
- self.hidden_channels = hidden_channels
577
- self.filter_channels = filter_channels
578
- self.n_heads = n_heads
579
- self.n_layers = n_layers
580
- self.kernel_size = kernel_size
581
- self.p_dropout = p_dropout
582
- self.resblock = resblock
583
- self.resblock_kernel_sizes = resblock_kernel_sizes
584
- self.resblock_dilation_sizes = resblock_dilation_sizes
585
- self.upsample_rates = upsample_rates
586
- self.upsample_initial_channel = upsample_initial_channel
587
- self.upsample_kernel_sizes = upsample_kernel_sizes
588
- self.segment_size = segment_size
589
- self.gin_channels = gin_channels
590
- # self.hop_length = hop_length#
591
- self.spk_embed_dim=spk_embed_dim
592
- self.enc_p = TextEncoder256Sim(
593
- inter_channels,
594
- hidden_channels,
595
- filter_channels,
596
- n_heads,
597
- n_layers,
598
- kernel_size,
599
- p_dropout,
600
- )
601
- self.dec = GeneratorNSF(
602
- inter_channels,
603
- resblock,
604
- resblock_kernel_sizes,
605
- resblock_dilation_sizes,
606
- upsample_rates,
607
- upsample_initial_channel,
608
- upsample_kernel_sizes,
609
- gin_channels=gin_channels,is_half=kwargs["is_half"]
610
- )
611
-
612
- self.flow = ResidualCouplingBlock(
613
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
614
- )
615
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
616
- print("gin_channels:",gin_channels,"self.spk_embed_dim:",self.spk_embed_dim)
617
- def remove_weight_norm(self):
618
- self.dec.remove_weight_norm()
619
- self.flow.remove_weight_norm()
620
- self.enc_q.remove_weight_norm()
621
-
622
- def forward(self, phone, phone_lengths, pitch, pitchf, ds,max_len=None): # y是spec不需要了现在
623
- g = self.emb_g(ds.unsqueeze(0)).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
624
- x, x_mask = self.enc_p(phone, pitch, phone_lengths)
625
- x = self.flow(x, x_mask, g=g, reverse=True)
626
- o = self.dec((x*x_mask)[:, :, :max_len], pitchf, g=g)
627
- return o
628
-
629
- class MultiPeriodDiscriminator(torch.nn.Module):
630
- def __init__(self, use_spectral_norm=False):
631
- super(MultiPeriodDiscriminator, self).__init__()
632
- periods = [2, 3, 5, 7, 11,17]
633
- # periods = [3, 5, 7, 11, 17, 23, 37]
634
-
635
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
636
- discs = discs + [
637
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
638
- ]
639
- self.discriminators = nn.ModuleList(discs)
640
-
641
- def forward(self, y, y_hat):
642
- y_d_rs = []#
643
- y_d_gs = []
644
- fmap_rs = []
645
- fmap_gs = []
646
- for i, d in enumerate(self.discriminators):
647
- y_d_r, fmap_r = d(y)
648
- y_d_g, fmap_g = d(y_hat)
649
- # for j in range(len(fmap_r)):
650
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
651
- y_d_rs.append(y_d_r)
652
- y_d_gs.append(y_d_g)
653
- fmap_rs.append(fmap_r)
654
- fmap_gs.append(fmap_g)
655
-
656
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
657
-
658
- class DiscriminatorS(torch.nn.Module):
659
- def __init__(self, use_spectral_norm=False):
660
- super(DiscriminatorS, self).__init__()
661
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
662
- self.convs = nn.ModuleList(
663
- [
664
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
665
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
666
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
667
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
668
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
669
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
670
- ]
671
- )
672
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
673
-
674
- def forward(self, x):
675
- fmap = []
676
-
677
- for l in self.convs:
678
- x = l(x)
679
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
680
- fmap.append(x)
681
- x = self.conv_post(x)
682
- fmap.append(x)
683
- x = torch.flatten(x, 1, -1)
684
-
685
- return x, fmap
686
-
687
- class DiscriminatorP(torch.nn.Module):
688
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
689
- super(DiscriminatorP, self).__init__()
690
- self.period = period
691
- self.use_spectral_norm = use_spectral_norm
692
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
693
- self.convs = nn.ModuleList(
694
- [
695
- norm_f(
696
- Conv2d(
697
- 1,
698
- 32,
699
- (kernel_size, 1),
700
- (stride, 1),
701
- padding=(get_padding(kernel_size, 1), 0),
702
- )
703
- ),
704
- norm_f(
705
- Conv2d(
706
- 32,
707
- 128,
708
- (kernel_size, 1),
709
- (stride, 1),
710
- padding=(get_padding(kernel_size, 1), 0),
711
- )
712
- ),
713
- norm_f(
714
- Conv2d(
715
- 128,
716
- 512,
717
- (kernel_size, 1),
718
- (stride, 1),
719
- padding=(get_padding(kernel_size, 1), 0),
720
- )
721
- ),
722
- norm_f(
723
- Conv2d(
724
- 512,
725
- 1024,
726
- (kernel_size, 1),
727
- (stride, 1),
728
- padding=(get_padding(kernel_size, 1), 0),
729
- )
730
- ),
731
- norm_f(
732
- Conv2d(
733
- 1024,
734
- 1024,
735
- (kernel_size, 1),
736
- 1,
737
- padding=(get_padding(kernel_size, 1), 0),
738
- )
739
- ),
740
- ]
741
- )
742
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
743
-
744
- def forward(self, x):
745
- fmap = []
746
-
747
- # 1d to 2d
748
- b, c, t = x.shape
749
- if t % self.period != 0: # pad first
750
- n_pad = self.period - (t % self.period)
751
- x = F.pad(x, (0, n_pad), "reflect")
752
- t = t + n_pad
753
- x = x.view(b, c, t // self.period, self.period)
754
-
755
- for l in self.convs:
756
- x = l(x)
757
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
758
- fmap.append(x)
759
- x = self.conv_post(x)
760
- fmap.append(x)
761
- x = torch.flatten(x, 1, -1)
762
-
763
- return x, fmap
764
-
 
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder256Sim(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(256, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ x = self.proj(x) * x_mask
106
+ return x, x_mask
107
+
108
+
109
+ class ResidualCouplingBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ channels,
113
+ hidden_channels,
114
+ kernel_size,
115
+ dilation_rate,
116
+ n_layers,
117
+ n_flows=4,
118
+ gin_channels=0,
119
+ ):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.hidden_channels = hidden_channels
123
+ self.kernel_size = kernel_size
124
+ self.dilation_rate = dilation_rate
125
+ self.n_layers = n_layers
126
+ self.n_flows = n_flows
127
+ self.gin_channels = gin_channels
128
+
129
+ self.flows = nn.ModuleList()
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.ResidualCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ dilation_rate,
137
+ n_layers,
138
+ gin_channels=gin_channels,
139
+ mean_only=True,
140
+ )
141
+ )
142
+ self.flows.append(modules.Flip())
143
+
144
+ def forward(self, x, x_mask, g=None, reverse=False):
145
+ if not reverse:
146
+ for flow in self.flows:
147
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
148
+ else:
149
+ for flow in reversed(self.flows):
150
+ x = flow(x, x_mask, g=g, reverse=reverse)
151
+ return x
152
+
153
+ def remove_weight_norm(self):
154
+ for i in range(self.n_flows):
155
+ self.flows[i * 2].remove_weight_norm()
156
+
157
+
158
+ class PosteriorEncoder(nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ out_channels,
163
+ hidden_channels,
164
+ kernel_size,
165
+ dilation_rate,
166
+ n_layers,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ self.in_channels = in_channels
171
+ self.out_channels = out_channels
172
+ self.hidden_channels = hidden_channels
173
+ self.kernel_size = kernel_size
174
+ self.dilation_rate = dilation_rate
175
+ self.n_layers = n_layers
176
+ self.gin_channels = gin_channels
177
+
178
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
179
+ self.enc = modules.WN(
180
+ hidden_channels,
181
+ kernel_size,
182
+ dilation_rate,
183
+ n_layers,
184
+ gin_channels=gin_channels,
185
+ )
186
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
187
+
188
+ def forward(self, x, x_lengths, g=None):
189
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
190
+ x.dtype
191
+ )
192
+ x = self.pre(x) * x_mask
193
+ x = self.enc(x, x_mask, g=g)
194
+ stats = self.proj(x) * x_mask
195
+ m, logs = torch.split(stats, self.out_channels, dim=1)
196
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
197
+ return z, m, logs, x_mask
198
+
199
+ def remove_weight_norm(self):
200
+ self.enc.remove_weight_norm()
201
+
202
+
203
+ class Generator(torch.nn.Module):
204
+ def __init__(
205
+ self,
206
+ initial_channel,
207
+ resblock,
208
+ resblock_kernel_sizes,
209
+ resblock_dilation_sizes,
210
+ upsample_rates,
211
+ upsample_initial_channel,
212
+ upsample_kernel_sizes,
213
+ gin_channels=0,
214
+ ):
215
+ super(Generator, self).__init__()
216
+ self.num_kernels = len(resblock_kernel_sizes)
217
+ self.num_upsamples = len(upsample_rates)
218
+ self.conv_pre = Conv1d(
219
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
220
+ )
221
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
222
+
223
+ self.ups = nn.ModuleList()
224
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
225
+ self.ups.append(
226
+ weight_norm(
227
+ ConvTranspose1d(
228
+ upsample_initial_channel // (2**i),
229
+ upsample_initial_channel // (2 ** (i + 1)),
230
+ k,
231
+ u,
232
+ padding=(k - u) // 2,
233
+ )
234
+ )
235
+ )
236
+
237
+ self.resblocks = nn.ModuleList()
238
+ for i in range(len(self.ups)):
239
+ ch = upsample_initial_channel // (2 ** (i + 1))
240
+ for j, (k, d) in enumerate(
241
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
242
+ ):
243
+ self.resblocks.append(resblock(ch, k, d))
244
+
245
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
246
+ self.ups.apply(init_weights)
247
+
248
+ if gin_channels != 0:
249
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
250
+
251
+ def forward(self, x, g=None):
252
+ x = self.conv_pre(x)
253
+ if g is not None:
254
+ x = x + self.cond(g)
255
+
256
+ for i in range(self.num_upsamples):
257
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
258
+ x = self.ups[i](x)
259
+ xs = None
260
+ for j in range(self.num_kernels):
261
+ if xs is None:
262
+ xs = self.resblocks[i * self.num_kernels + j](x)
263
+ else:
264
+ xs += self.resblocks[i * self.num_kernels + j](x)
265
+ x = xs / self.num_kernels
266
+ x = F.leaky_relu(x)
267
+ x = self.conv_post(x)
268
+ x = torch.tanh(x)
269
+
270
+ return x
271
+
272
+ def remove_weight_norm(self):
273
+ for l in self.ups:
274
+ remove_weight_norm(l)
275
+ for l in self.resblocks:
276
+ l.remove_weight_norm()
277
+
278
+
279
+ class SineGen(torch.nn.Module):
280
+ """Definition of sine generator
281
+ SineGen(samp_rate, harmonic_num = 0,
282
+ sine_amp = 0.1, noise_std = 0.003,
283
+ voiced_threshold = 0,
284
+ flag_for_pulse=False)
285
+ samp_rate: sampling rate in Hz
286
+ harmonic_num: number of harmonic overtones (default 0)
287
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
288
+ noise_std: std of Gaussian noise (default 0.003)
289
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
290
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
291
+ Note: when flag_for_pulse is True, the first time step of a voiced
292
+ segment is always sin(np.pi) or cos(0)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ samp_rate,
298
+ harmonic_num=0,
299
+ sine_amp=0.1,
300
+ noise_std=0.003,
301
+ voiced_threshold=0,
302
+ flag_for_pulse=False,
303
+ ):
304
+ super(SineGen, self).__init__()
305
+ self.sine_amp = sine_amp
306
+ self.noise_std = noise_std
307
+ self.harmonic_num = harmonic_num
308
+ self.dim = self.harmonic_num + 1
309
+ self.sampling_rate = samp_rate
310
+ self.voiced_threshold = voiced_threshold
311
+
312
+ def _f02uv(self, f0):
313
+ # generate uv signal
314
+ uv = torch.ones_like(f0)
315
+ uv = uv * (f0 > self.voiced_threshold)
316
+ return uv
317
+
318
+ def forward(self, f0, upp):
319
+ """sine_tensor, uv = forward(f0)
320
+ input F0: tensor(batchsize=1, length, dim=1)
321
+ f0 for unvoiced steps should be 0
322
+ output sine_tensor: tensor(batchsize=1, length, dim)
323
+ output uv: tensor(batchsize=1, length, 1)
324
+ """
325
+ with torch.no_grad():
326
+ f0 = f0[:, None].transpose(1, 2)
327
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
328
+ # fundamental component
329
+ f0_buf[:, :, 0] = f0[:, :, 0]
330
+ for idx in np.arange(self.harmonic_num):
331
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
332
+ idx + 2
333
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
334
+ rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
335
+ rand_ini = torch.rand(
336
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
337
+ )
338
+ rand_ini[:, 0] = 0
339
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
340
+ tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
341
+ tmp_over_one *= upp
342
+ tmp_over_one = F.interpolate(
343
+ tmp_over_one.transpose(2, 1),
344
+ scale_factor=upp,
345
+ mode="linear",
346
+ align_corners=True,
347
+ ).transpose(2, 1)
348
+ rad_values = F.interpolate(
349
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
350
+ ).transpose(
351
+ 2, 1
352
+ ) #######
353
+ tmp_over_one %= 1
354
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
355
+ cumsum_shift = torch.zeros_like(rad_values)
356
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
357
+ sine_waves = torch.sin(
358
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
359
+ )
360
+ sine_waves = sine_waves * self.sine_amp
361
+ uv = self._f02uv(f0)
362
+ uv = F.interpolate(
363
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
364
+ ).transpose(2, 1)
365
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
366
+ noise = noise_amp * torch.randn_like(sine_waves)
367
+ sine_waves = sine_waves * uv + noise
368
+ return sine_waves, uv, noise
369
+
370
+
371
+ class SourceModuleHnNSF(torch.nn.Module):
372
+ """SourceModule for hn-nsf
373
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
374
+ add_noise_std=0.003, voiced_threshod=0)
375
+ sampling_rate: sampling_rate in Hz
376
+ harmonic_num: number of harmonic above F0 (default: 0)
377
+ sine_amp: amplitude of sine source signal (default: 0.1)
378
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
379
+ note that amplitude of noise in unvoiced is decided
380
+ by sine_amp
381
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
382
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
383
+ F0_sampled (batchsize, length, 1)
384
+ Sine_source (batchsize, length, 1)
385
+ noise_source (batchsize, length 1)
386
+ uv (batchsize, length, 1)
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ sampling_rate,
392
+ harmonic_num=0,
393
+ sine_amp=0.1,
394
+ add_noise_std=0.003,
395
+ voiced_threshod=0,
396
+ is_half=True,
397
+ ):
398
+ super(SourceModuleHnNSF, self).__init__()
399
+
400
+ self.sine_amp = sine_amp
401
+ self.noise_std = add_noise_std
402
+ self.is_half = is_half
403
+ # to produce sine waveforms
404
+ self.l_sin_gen = SineGen(
405
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
406
+ )
407
+
408
+ # to merge source harmonics into a single excitation
409
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
410
+ self.l_tanh = torch.nn.Tanh()
411
+
412
+ def forward(self, x, upp=None):
413
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
414
+ if self.is_half:
415
+ sine_wavs = sine_wavs.half()
416
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
417
+ return sine_merge, None, None # noise, uv
418
+
419
+
420
+ class GeneratorNSF(torch.nn.Module):
421
+ def __init__(
422
+ self,
423
+ initial_channel,
424
+ resblock,
425
+ resblock_kernel_sizes,
426
+ resblock_dilation_sizes,
427
+ upsample_rates,
428
+ upsample_initial_channel,
429
+ upsample_kernel_sizes,
430
+ gin_channels,
431
+ sr,
432
+ is_half=False,
433
+ ):
434
+ super(GeneratorNSF, self).__init__()
435
+ self.num_kernels = len(resblock_kernel_sizes)
436
+ self.num_upsamples = len(upsample_rates)
437
+
438
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
439
+ self.m_source = SourceModuleHnNSF(
440
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
441
+ )
442
+ self.noise_convs = nn.ModuleList()
443
+ self.conv_pre = Conv1d(
444
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
445
+ )
446
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
447
+
448
+ self.ups = nn.ModuleList()
449
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
450
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
451
+ self.ups.append(
452
+ weight_norm(
453
+ ConvTranspose1d(
454
+ upsample_initial_channel // (2**i),
455
+ upsample_initial_channel // (2 ** (i + 1)),
456
+ k,
457
+ u,
458
+ padding=(k - u) // 2,
459
+ )
460
+ )
461
+ )
462
+ if i + 1 < len(upsample_rates):
463
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
464
+ self.noise_convs.append(
465
+ Conv1d(
466
+ 1,
467
+ c_cur,
468
+ kernel_size=stride_f0 * 2,
469
+ stride=stride_f0,
470
+ padding=stride_f0 // 2,
471
+ )
472
+ )
473
+ else:
474
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
475
+
476
+ self.resblocks = nn.ModuleList()
477
+ for i in range(len(self.ups)):
478
+ ch = upsample_initial_channel // (2 ** (i + 1))
479
+ for j, (k, d) in enumerate(
480
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
481
+ ):
482
+ self.resblocks.append(resblock(ch, k, d))
483
+
484
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
485
+ self.ups.apply(init_weights)
486
+
487
+ if gin_channels != 0:
488
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
489
+
490
+ self.upp = np.prod(upsample_rates)
491
+
492
+ def forward(self, x, f0, g=None):
493
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
494
+ har_source = har_source.transpose(1, 2)
495
+ x = self.conv_pre(x)
496
+ if g is not None:
497
+ x = x + self.cond(g)
498
+
499
+ for i in range(self.num_upsamples):
500
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
501
+ x = self.ups[i](x)
502
+ x_source = self.noise_convs[i](har_source)
503
+ x = x + x_source
504
+ xs = None
505
+ for j in range(self.num_kernels):
506
+ if xs is None:
507
+ xs = self.resblocks[i * self.num_kernels + j](x)
508
+ else:
509
+ xs += self.resblocks[i * self.num_kernels + j](x)
510
+ x = xs / self.num_kernels
511
+ x = F.leaky_relu(x)
512
+ x = self.conv_post(x)
513
+ x = torch.tanh(x)
514
+ return x
515
+
516
+ def remove_weight_norm(self):
517
+ for l in self.ups:
518
+ remove_weight_norm(l)
519
+ for l in self.resblocks:
520
+ l.remove_weight_norm()
521
+
522
+
523
+ sr2sr = {
524
+ "32k": 32000,
525
+ "40k": 40000,
526
+ "48k": 48000,
527
+ }
528
+
529
+
530
+ class SynthesizerTrnMs256NSFsidO(nn.Module):
531
+ def __init__(
532
+ self,
533
+ spec_channels,
534
+ segment_size,
535
+ inter_channels,
536
+ hidden_channels,
537
+ filter_channels,
538
+ n_heads,
539
+ n_layers,
540
+ kernel_size,
541
+ p_dropout,
542
+ resblock,
543
+ resblock_kernel_sizes,
544
+ resblock_dilation_sizes,
545
+ upsample_rates,
546
+ upsample_initial_channel,
547
+ upsample_kernel_sizes,
548
+ spk_embed_dim,
549
+ gin_channels,
550
+ sr,
551
+ **kwargs
552
+ ):
553
+ super().__init__()
554
+ if type(sr) == type("strr"):
555
+ sr = sr2sr[sr]
556
+ self.spec_channels = spec_channels
557
+ self.inter_channels = inter_channels
558
+ self.hidden_channels = hidden_channels
559
+ self.filter_channels = filter_channels
560
+ self.n_heads = n_heads
561
+ self.n_layers = n_layers
562
+ self.kernel_size = kernel_size
563
+ self.p_dropout = p_dropout
564
+ self.resblock = resblock
565
+ self.resblock_kernel_sizes = resblock_kernel_sizes
566
+ self.resblock_dilation_sizes = resblock_dilation_sizes
567
+ self.upsample_rates = upsample_rates
568
+ self.upsample_initial_channel = upsample_initial_channel
569
+ self.upsample_kernel_sizes = upsample_kernel_sizes
570
+ self.segment_size = segment_size
571
+ self.gin_channels = gin_channels
572
+ # self.hop_length = hop_length#
573
+ self.spk_embed_dim = spk_embed_dim
574
+ self.enc_p = TextEncoder256(
575
+ inter_channels,
576
+ hidden_channels,
577
+ filter_channels,
578
+ n_heads,
579
+ n_layers,
580
+ kernel_size,
581
+ p_dropout,
582
+ )
583
+ self.dec = GeneratorNSF(
584
+ inter_channels,
585
+ resblock,
586
+ resblock_kernel_sizes,
587
+ resblock_dilation_sizes,
588
+ upsample_rates,
589
+ upsample_initial_channel,
590
+ upsample_kernel_sizes,
591
+ gin_channels=gin_channels,
592
+ sr=sr,
593
+ is_half=kwargs["is_half"],
594
+ )
595
+ self.enc_q = PosteriorEncoder(
596
+ spec_channels,
597
+ inter_channels,
598
+ hidden_channels,
599
+ 5,
600
+ 1,
601
+ 16,
602
+ gin_channels=gin_channels,
603
+ )
604
+ self.flow = ResidualCouplingBlock(
605
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
606
+ )
607
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
608
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
609
+
610
+ def remove_weight_norm(self):
611
+ self.dec.remove_weight_norm()
612
+ self.flow.remove_weight_norm()
613
+ self.enc_q.remove_weight_norm()
614
+
615
+ def forward(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
616
+ g = self.emb_g(sid).unsqueeze(-1)
617
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
618
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
619
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
620
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
621
+ return o
622
+
623
+
624
+ class MultiPeriodDiscriminator(torch.nn.Module):
625
+ def __init__(self, use_spectral_norm=False):
626
+ super(MultiPeriodDiscriminator, self).__init__()
627
+ periods = [2, 3, 5, 7, 11, 17]
628
+ # periods = [3, 5, 7, 11, 17, 23, 37]
629
+
630
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
631
+ discs = discs + [
632
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
633
+ ]
634
+ self.discriminators = nn.ModuleList(discs)
635
+
636
+ def forward(self, y, y_hat):
637
+ y_d_rs = [] #
638
+ y_d_gs = []
639
+ fmap_rs = []
640
+ fmap_gs = []
641
+ for i, d in enumerate(self.discriminators):
642
+ y_d_r, fmap_r = d(y)
643
+ y_d_g, fmap_g = d(y_hat)
644
+ # for j in range(len(fmap_r)):
645
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
646
+ y_d_rs.append(y_d_r)
647
+ y_d_gs.append(y_d_g)
648
+ fmap_rs.append(fmap_r)
649
+ fmap_gs.append(fmap_g)
650
+
651
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
652
+
653
+
654
+ class DiscriminatorS(torch.nn.Module):
655
+ def __init__(self, use_spectral_norm=False):
656
+ super(DiscriminatorS, self).__init__()
657
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
658
+ self.convs = nn.ModuleList(
659
+ [
660
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
661
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
662
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
663
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
664
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
665
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
666
+ ]
667
+ )
668
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
669
+
670
+ def forward(self, x):
671
+ fmap = []
672
+
673
+ for l in self.convs:
674
+ x = l(x)
675
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
676
+ fmap.append(x)
677
+ x = self.conv_post(x)
678
+ fmap.append(x)
679
+ x = torch.flatten(x, 1, -1)
680
+
681
+ return x, fmap
682
+
683
+
684
+ class DiscriminatorP(torch.nn.Module):
685
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
686
+ super(DiscriminatorP, self).__init__()
687
+ self.period = period
688
+ self.use_spectral_norm = use_spectral_norm
689
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
690
+ self.convs = nn.ModuleList(
691
+ [
692
+ norm_f(
693
+ Conv2d(
694
+ 1,
695
+ 32,
696
+ (kernel_size, 1),
697
+ (stride, 1),
698
+ padding=(get_padding(kernel_size, 1), 0),
699
+ )
700
+ ),
701
+ norm_f(
702
+ Conv2d(
703
+ 32,
704
+ 128,
705
+ (kernel_size, 1),
706
+ (stride, 1),
707
+ padding=(get_padding(kernel_size, 1), 0),
708
+ )
709
+ ),
710
+ norm_f(
711
+ Conv2d(
712
+ 128,
713
+ 512,
714
+ (kernel_size, 1),
715
+ (stride, 1),
716
+ padding=(get_padding(kernel_size, 1), 0),
717
+ )
718
+ ),
719
+ norm_f(
720
+ Conv2d(
721
+ 512,
722
+ 1024,
723
+ (kernel_size, 1),
724
+ (stride, 1),
725
+ padding=(get_padding(kernel_size, 1), 0),
726
+ )
727
+ ),
728
+ norm_f(
729
+ Conv2d(
730
+ 1024,
731
+ 1024,
732
+ (kernel_size, 1),
733
+ 1,
734
+ padding=(get_padding(kernel_size, 1), 0),
735
+ )
736
+ ),
737
+ ]
738
+ )
739
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
740
+
741
+ def forward(self, x):
742
+ fmap = []
743
+
744
+ # 1d to 2d
745
+ b, c, t = x.shape
746
+ if t % self.period != 0: # pad first
747
+ n_pad = self.period - (t % self.period)
748
+ x = F.pad(x, (0, n_pad), "reflect")
749
+ t = t + n_pad
750
+ x = x.view(b, c, t // self.period, self.period)
751
+
752
+ for l in self.convs:
753
+ x = l(x)
754
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
755
+ fmap.append(x)
756
+ x = self.conv_post(x)
757
+ fmap.append(x)
758
+ x = torch.flatten(x, 1, -1)
759
+
760
+ return x, fmap
 
 
 
 
infer_pack/models_onnx_moess.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder256Sim(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(256, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ x = self.proj(x) * x_mask
106
+ return x, x_mask
107
+
108
+
109
+ class ResidualCouplingBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ channels,
113
+ hidden_channels,
114
+ kernel_size,
115
+ dilation_rate,
116
+ n_layers,
117
+ n_flows=4,
118
+ gin_channels=0,
119
+ ):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.hidden_channels = hidden_channels
123
+ self.kernel_size = kernel_size
124
+ self.dilation_rate = dilation_rate
125
+ self.n_layers = n_layers
126
+ self.n_flows = n_flows
127
+ self.gin_channels = gin_channels
128
+
129
+ self.flows = nn.ModuleList()
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.ResidualCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ dilation_rate,
137
+ n_layers,
138
+ gin_channels=gin_channels,
139
+ mean_only=True,
140
+ )
141
+ )
142
+ self.flows.append(modules.Flip())
143
+
144
+ def forward(self, x, x_mask, g=None, reverse=False):
145
+ if not reverse:
146
+ for flow in self.flows:
147
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
148
+ else:
149
+ for flow in reversed(self.flows):
150
+ x = flow(x, x_mask, g=g, reverse=reverse)
151
+ return x
152
+
153
+ def remove_weight_norm(self):
154
+ for i in range(self.n_flows):
155
+ self.flows[i * 2].remove_weight_norm()
156
+
157
+
158
+ class PosteriorEncoder(nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ out_channels,
163
+ hidden_channels,
164
+ kernel_size,
165
+ dilation_rate,
166
+ n_layers,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ self.in_channels = in_channels
171
+ self.out_channels = out_channels
172
+ self.hidden_channels = hidden_channels
173
+ self.kernel_size = kernel_size
174
+ self.dilation_rate = dilation_rate
175
+ self.n_layers = n_layers
176
+ self.gin_channels = gin_channels
177
+
178
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
179
+ self.enc = modules.WN(
180
+ hidden_channels,
181
+ kernel_size,
182
+ dilation_rate,
183
+ n_layers,
184
+ gin_channels=gin_channels,
185
+ )
186
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
187
+
188
+ def forward(self, x, x_lengths, g=None):
189
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
190
+ x.dtype
191
+ )
192
+ x = self.pre(x) * x_mask
193
+ x = self.enc(x, x_mask, g=g)
194
+ stats = self.proj(x) * x_mask
195
+ m, logs = torch.split(stats, self.out_channels, dim=1)
196
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
197
+ return z, m, logs, x_mask
198
+
199
+ def remove_weight_norm(self):
200
+ self.enc.remove_weight_norm()
201
+
202
+
203
+ class Generator(torch.nn.Module):
204
+ def __init__(
205
+ self,
206
+ initial_channel,
207
+ resblock,
208
+ resblock_kernel_sizes,
209
+ resblock_dilation_sizes,
210
+ upsample_rates,
211
+ upsample_initial_channel,
212
+ upsample_kernel_sizes,
213
+ gin_channels=0,
214
+ ):
215
+ super(Generator, self).__init__()
216
+ self.num_kernels = len(resblock_kernel_sizes)
217
+ self.num_upsamples = len(upsample_rates)
218
+ self.conv_pre = Conv1d(
219
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
220
+ )
221
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
222
+
223
+ self.ups = nn.ModuleList()
224
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
225
+ self.ups.append(
226
+ weight_norm(
227
+ ConvTranspose1d(
228
+ upsample_initial_channel // (2**i),
229
+ upsample_initial_channel // (2 ** (i + 1)),
230
+ k,
231
+ u,
232
+ padding=(k - u) // 2,
233
+ )
234
+ )
235
+ )
236
+
237
+ self.resblocks = nn.ModuleList()
238
+ for i in range(len(self.ups)):
239
+ ch = upsample_initial_channel // (2 ** (i + 1))
240
+ for j, (k, d) in enumerate(
241
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
242
+ ):
243
+ self.resblocks.append(resblock(ch, k, d))
244
+
245
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
246
+ self.ups.apply(init_weights)
247
+
248
+ if gin_channels != 0:
249
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
250
+
251
+ def forward(self, x, g=None):
252
+ x = self.conv_pre(x)
253
+ if g is not None:
254
+ x = x + self.cond(g)
255
+
256
+ for i in range(self.num_upsamples):
257
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
258
+ x = self.ups[i](x)
259
+ xs = None
260
+ for j in range(self.num_kernels):
261
+ if xs is None:
262
+ xs = self.resblocks[i * self.num_kernels + j](x)
263
+ else:
264
+ xs += self.resblocks[i * self.num_kernels + j](x)
265
+ x = xs / self.num_kernels
266
+ x = F.leaky_relu(x)
267
+ x = self.conv_post(x)
268
+ x = torch.tanh(x)
269
+
270
+ return x
271
+
272
+ def remove_weight_norm(self):
273
+ for l in self.ups:
274
+ remove_weight_norm(l)
275
+ for l in self.resblocks:
276
+ l.remove_weight_norm()
277
+
278
+
279
+ class SineGen(torch.nn.Module):
280
+ """Definition of sine generator
281
+ SineGen(samp_rate, harmonic_num = 0,
282
+ sine_amp = 0.1, noise_std = 0.003,
283
+ voiced_threshold = 0,
284
+ flag_for_pulse=False)
285
+ samp_rate: sampling rate in Hz
286
+ harmonic_num: number of harmonic overtones (default 0)
287
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
288
+ noise_std: std of Gaussian noise (default 0.003)
289
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
290
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
291
+ Note: when flag_for_pulse is True, the first time step of a voiced
292
+ segment is always sin(np.pi) or cos(0)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ samp_rate,
298
+ harmonic_num=0,
299
+ sine_amp=0.1,
300
+ noise_std=0.003,
301
+ voiced_threshold=0,
302
+ flag_for_pulse=False,
303
+ ):
304
+ super(SineGen, self).__init__()
305
+ self.sine_amp = sine_amp
306
+ self.noise_std = noise_std
307
+ self.harmonic_num = harmonic_num
308
+ self.dim = self.harmonic_num + 1
309
+ self.sampling_rate = samp_rate
310
+ self.voiced_threshold = voiced_threshold
311
+
312
+ def _f02uv(self, f0):
313
+ # generate uv signal
314
+ uv = torch.ones_like(f0)
315
+ uv = uv * (f0 > self.voiced_threshold)
316
+ return uv
317
+
318
+ def forward(self, f0, upp):
319
+ """sine_tensor, uv = forward(f0)
320
+ input F0: tensor(batchsize=1, length, dim=1)
321
+ f0 for unvoiced steps should be 0
322
+ output sine_tensor: tensor(batchsize=1, length, dim)
323
+ output uv: tensor(batchsize=1, length, 1)
324
+ """
325
+ with torch.no_grad():
326
+ f0 = f0[:, None].transpose(1, 2)
327
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
328
+ # fundamental component
329
+ f0_buf[:, :, 0] = f0[:, :, 0]
330
+ for idx in np.arange(self.harmonic_num):
331
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
332
+ idx + 2
333
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
334
+ rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
335
+ rand_ini = torch.rand(
336
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
337
+ )
338
+ rand_ini[:, 0] = 0
339
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
340
+ tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
341
+ tmp_over_one *= upp
342
+ tmp_over_one = F.interpolate(
343
+ tmp_over_one.transpose(2, 1),
344
+ scale_factor=upp,
345
+ mode="linear",
346
+ align_corners=True,
347
+ ).transpose(2, 1)
348
+ rad_values = F.interpolate(
349
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
350
+ ).transpose(
351
+ 2, 1
352
+ ) #######
353
+ tmp_over_one %= 1
354
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
355
+ cumsum_shift = torch.zeros_like(rad_values)
356
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
357
+ sine_waves = torch.sin(
358
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
359
+ )
360
+ sine_waves = sine_waves * self.sine_amp
361
+ uv = self._f02uv(f0)
362
+ uv = F.interpolate(
363
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
364
+ ).transpose(2, 1)
365
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
366
+ noise = noise_amp * torch.randn_like(sine_waves)
367
+ sine_waves = sine_waves * uv + noise
368
+ return sine_waves, uv, noise
369
+
370
+
371
+ class SourceModuleHnNSF(torch.nn.Module):
372
+ """SourceModule for hn-nsf
373
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
374
+ add_noise_std=0.003, voiced_threshod=0)
375
+ sampling_rate: sampling_rate in Hz
376
+ harmonic_num: number of harmonic above F0 (default: 0)
377
+ sine_amp: amplitude of sine source signal (default: 0.1)
378
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
379
+ note that amplitude of noise in unvoiced is decided
380
+ by sine_amp
381
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
382
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
383
+ F0_sampled (batchsize, length, 1)
384
+ Sine_source (batchsize, length, 1)
385
+ noise_source (batchsize, length 1)
386
+ uv (batchsize, length, 1)
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ sampling_rate,
392
+ harmonic_num=0,
393
+ sine_amp=0.1,
394
+ add_noise_std=0.003,
395
+ voiced_threshod=0,
396
+ is_half=True,
397
+ ):
398
+ super(SourceModuleHnNSF, self).__init__()
399
+
400
+ self.sine_amp = sine_amp
401
+ self.noise_std = add_noise_std
402
+ self.is_half = is_half
403
+ # to produce sine waveforms
404
+ self.l_sin_gen = SineGen(
405
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
406
+ )
407
+
408
+ # to merge source harmonics into a single excitation
409
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
410
+ self.l_tanh = torch.nn.Tanh()
411
+
412
+ def forward(self, x, upp=None):
413
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
414
+ if self.is_half:
415
+ sine_wavs = sine_wavs.half()
416
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
417
+ return sine_merge, None, None # noise, uv
418
+
419
+
420
+ class GeneratorNSF(torch.nn.Module):
421
+ def __init__(
422
+ self,
423
+ initial_channel,
424
+ resblock,
425
+ resblock_kernel_sizes,
426
+ resblock_dilation_sizes,
427
+ upsample_rates,
428
+ upsample_initial_channel,
429
+ upsample_kernel_sizes,
430
+ gin_channels,
431
+ sr,
432
+ is_half=False,
433
+ ):
434
+ super(GeneratorNSF, self).__init__()
435
+ self.num_kernels = len(resblock_kernel_sizes)
436
+ self.num_upsamples = len(upsample_rates)
437
+
438
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
439
+ self.m_source = SourceModuleHnNSF(
440
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
441
+ )
442
+ self.noise_convs = nn.ModuleList()
443
+ self.conv_pre = Conv1d(
444
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
445
+ )
446
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
447
+
448
+ self.ups = nn.ModuleList()
449
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
450
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
451
+ self.ups.append(
452
+ weight_norm(
453
+ ConvTranspose1d(
454
+ upsample_initial_channel // (2**i),
455
+ upsample_initial_channel // (2 ** (i + 1)),
456
+ k,
457
+ u,
458
+ padding=(k - u) // 2,
459
+ )
460
+ )
461
+ )
462
+ if i + 1 < len(upsample_rates):
463
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
464
+ self.noise_convs.append(
465
+ Conv1d(
466
+ 1,
467
+ c_cur,
468
+ kernel_size=stride_f0 * 2,
469
+ stride=stride_f0,
470
+ padding=stride_f0 // 2,
471
+ )
472
+ )
473
+ else:
474
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
475
+
476
+ self.resblocks = nn.ModuleList()
477
+ for i in range(len(self.ups)):
478
+ ch = upsample_initial_channel // (2 ** (i + 1))
479
+ for j, (k, d) in enumerate(
480
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
481
+ ):
482
+ self.resblocks.append(resblock(ch, k, d))
483
+
484
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
485
+ self.ups.apply(init_weights)
486
+
487
+ if gin_channels != 0:
488
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
489
+
490
+ self.upp = np.prod(upsample_rates)
491
+
492
+ def forward(self, x, f0, g=None):
493
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
494
+ har_source = har_source.transpose(1, 2)
495
+ x = self.conv_pre(x)
496
+ if g is not None:
497
+ x = x + self.cond(g)
498
+
499
+ for i in range(self.num_upsamples):
500
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
501
+ x = self.ups[i](x)
502
+ x_source = self.noise_convs[i](har_source)
503
+ x = x + x_source
504
+ xs = None
505
+ for j in range(self.num_kernels):
506
+ if xs is None:
507
+ xs = self.resblocks[i * self.num_kernels + j](x)
508
+ else:
509
+ xs += self.resblocks[i * self.num_kernels + j](x)
510
+ x = xs / self.num_kernels
511
+ x = F.leaky_relu(x)
512
+ x = self.conv_post(x)
513
+ x = torch.tanh(x)
514
+ return x
515
+
516
+ def remove_weight_norm(self):
517
+ for l in self.ups:
518
+ remove_weight_norm(l)
519
+ for l in self.resblocks:
520
+ l.remove_weight_norm()
521
+
522
+
523
+ sr2sr = {
524
+ "32k": 32000,
525
+ "40k": 40000,
526
+ "48k": 48000,
527
+ }
528
+
529
+
530
+ class SynthesizerTrnMs256NSFsidM(nn.Module):
531
+ def __init__(
532
+ self,
533
+ spec_channels,
534
+ segment_size,
535
+ inter_channels,
536
+ hidden_channels,
537
+ filter_channels,
538
+ n_heads,
539
+ n_layers,
540
+ kernel_size,
541
+ p_dropout,
542
+ resblock,
543
+ resblock_kernel_sizes,
544
+ resblock_dilation_sizes,
545
+ upsample_rates,
546
+ upsample_initial_channel,
547
+ upsample_kernel_sizes,
548
+ spk_embed_dim,
549
+ gin_channels,
550
+ sr,
551
+ **kwargs
552
+ ):
553
+ super().__init__()
554
+ if type(sr) == type("strr"):
555
+ sr = sr2sr[sr]
556
+ self.spec_channels = spec_channels
557
+ self.inter_channels = inter_channels
558
+ self.hidden_channels = hidden_channels
559
+ self.filter_channels = filter_channels
560
+ self.n_heads = n_heads
561
+ self.n_layers = n_layers
562
+ self.kernel_size = kernel_size
563
+ self.p_dropout = p_dropout
564
+ self.resblock = resblock
565
+ self.resblock_kernel_sizes = resblock_kernel_sizes
566
+ self.resblock_dilation_sizes = resblock_dilation_sizes
567
+ self.upsample_rates = upsample_rates
568
+ self.upsample_initial_channel = upsample_initial_channel
569
+ self.upsample_kernel_sizes = upsample_kernel_sizes
570
+ self.segment_size = segment_size
571
+ self.gin_channels = gin_channels
572
+ # self.hop_length = hop_length#
573
+ self.spk_embed_dim = spk_embed_dim
574
+ self.enc_p = TextEncoder256(
575
+ inter_channels,
576
+ hidden_channels,
577
+ filter_channels,
578
+ n_heads,
579
+ n_layers,
580
+ kernel_size,
581
+ p_dropout,
582
+ )
583
+ self.dec = GeneratorNSF(
584
+ inter_channels,
585
+ resblock,
586
+ resblock_kernel_sizes,
587
+ resblock_dilation_sizes,
588
+ upsample_rates,
589
+ upsample_initial_channel,
590
+ upsample_kernel_sizes,
591
+ gin_channels=gin_channels,
592
+ sr=sr,
593
+ is_half=kwargs["is_half"],
594
+ )
595
+ self.enc_q = PosteriorEncoder(
596
+ spec_channels,
597
+ inter_channels,
598
+ hidden_channels,
599
+ 5,
600
+ 1,
601
+ 16,
602
+ gin_channels=gin_channels,
603
+ )
604
+ self.flow = ResidualCouplingBlock(
605
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
606
+ )
607
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
608
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
609
+
610
+ def remove_weight_norm(self):
611
+ self.dec.remove_weight_norm()
612
+ self.flow.remove_weight_norm()
613
+ self.enc_q.remove_weight_norm()
614
+
615
+ def forward(self, phone, phone_lengths, pitch, nsff0, sid, rnd, max_len=None):
616
+ g = self.emb_g(sid).unsqueeze(-1)
617
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
618
+ z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
619
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
620
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
621
+ return o
622
+
623
+
624
+ class SynthesizerTrnMs256NSFsid_sim(nn.Module):
625
+ """
626
+ Synthesizer for Training
627
+ """
628
+
629
+ def __init__(
630
+ self,
631
+ spec_channels,
632
+ segment_size,
633
+ inter_channels,
634
+ hidden_channels,
635
+ filter_channels,
636
+ n_heads,
637
+ n_layers,
638
+ kernel_size,
639
+ p_dropout,
640
+ resblock,
641
+ resblock_kernel_sizes,
642
+ resblock_dilation_sizes,
643
+ upsample_rates,
644
+ upsample_initial_channel,
645
+ upsample_kernel_sizes,
646
+ spk_embed_dim,
647
+ # hop_length,
648
+ gin_channels=0,
649
+ use_sdp=True,
650
+ **kwargs
651
+ ):
652
+ super().__init__()
653
+ self.spec_channels = spec_channels
654
+ self.inter_channels = inter_channels
655
+ self.hidden_channels = hidden_channels
656
+ self.filter_channels = filter_channels
657
+ self.n_heads = n_heads
658
+ self.n_layers = n_layers
659
+ self.kernel_size = kernel_size
660
+ self.p_dropout = p_dropout
661
+ self.resblock = resblock
662
+ self.resblock_kernel_sizes = resblock_kernel_sizes
663
+ self.resblock_dilation_sizes = resblock_dilation_sizes
664
+ self.upsample_rates = upsample_rates
665
+ self.upsample_initial_channel = upsample_initial_channel
666
+ self.upsample_kernel_sizes = upsample_kernel_sizes
667
+ self.segment_size = segment_size
668
+ self.gin_channels = gin_channels
669
+ # self.hop_length = hop_length#
670
+ self.spk_embed_dim = spk_embed_dim
671
+ self.enc_p = TextEncoder256Sim(
672
+ inter_channels,
673
+ hidden_channels,
674
+ filter_channels,
675
+ n_heads,
676
+ n_layers,
677
+ kernel_size,
678
+ p_dropout,
679
+ )
680
+ self.dec = GeneratorNSF(
681
+ inter_channels,
682
+ resblock,
683
+ resblock_kernel_sizes,
684
+ resblock_dilation_sizes,
685
+ upsample_rates,
686
+ upsample_initial_channel,
687
+ upsample_kernel_sizes,
688
+ gin_channels=gin_channels,
689
+ is_half=kwargs["is_half"],
690
+ )
691
+
692
+ self.flow = ResidualCouplingBlock(
693
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
694
+ )
695
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
696
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
697
+
698
+ def remove_weight_norm(self):
699
+ self.dec.remove_weight_norm()
700
+ self.flow.remove_weight_norm()
701
+ self.enc_q.remove_weight_norm()
702
+
703
+ def forward(
704
+ self, phone, phone_lengths, pitch, pitchf, ds, max_len=None
705
+ ): # y是spec不需要了现在
706
+ g = self.emb_g(ds.unsqueeze(0)).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
707
+ x, x_mask = self.enc_p(phone, pitch, phone_lengths)
708
+ x = self.flow(x, x_mask, g=g, reverse=True)
709
+ o = self.dec((x * x_mask)[:, :, :max_len], pitchf, g=g)
710
+ return o
711
+
712
+
713
+ class MultiPeriodDiscriminator(torch.nn.Module):
714
+ def __init__(self, use_spectral_norm=False):
715
+ super(MultiPeriodDiscriminator, self).__init__()
716
+ periods = [2, 3, 5, 7, 11, 17]
717
+ # periods = [3, 5, 7, 11, 17, 23, 37]
718
+
719
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
720
+ discs = discs + [
721
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
722
+ ]
723
+ self.discriminators = nn.ModuleList(discs)
724
+
725
+ def forward(self, y, y_hat):
726
+ y_d_rs = [] #
727
+ y_d_gs = []
728
+ fmap_rs = []
729
+ fmap_gs = []
730
+ for i, d in enumerate(self.discriminators):
731
+ y_d_r, fmap_r = d(y)
732
+ y_d_g, fmap_g = d(y_hat)
733
+ # for j in range(len(fmap_r)):
734
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
735
+ y_d_rs.append(y_d_r)
736
+ y_d_gs.append(y_d_g)
737
+ fmap_rs.append(fmap_r)
738
+ fmap_gs.append(fmap_g)
739
+
740
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
741
+
742
+
743
+ class DiscriminatorS(torch.nn.Module):
744
+ def __init__(self, use_spectral_norm=False):
745
+ super(DiscriminatorS, self).__init__()
746
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
747
+ self.convs = nn.ModuleList(
748
+ [
749
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
750
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
751
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
752
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
753
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
754
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
755
+ ]
756
+ )
757
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
758
+
759
+ def forward(self, x):
760
+ fmap = []
761
+
762
+ for l in self.convs:
763
+ x = l(x)
764
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
765
+ fmap.append(x)
766
+ x = self.conv_post(x)
767
+ fmap.append(x)
768
+ x = torch.flatten(x, 1, -1)
769
+
770
+ return x, fmap
771
+
772
+
773
+ class DiscriminatorP(torch.nn.Module):
774
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
775
+ super(DiscriminatorP, self).__init__()
776
+ self.period = period
777
+ self.use_spectral_norm = use_spectral_norm
778
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
779
+ self.convs = nn.ModuleList(
780
+ [
781
+ norm_f(
782
+ Conv2d(
783
+ 1,
784
+ 32,
785
+ (kernel_size, 1),
786
+ (stride, 1),
787
+ padding=(get_padding(kernel_size, 1), 0),
788
+ )
789
+ ),
790
+ norm_f(
791
+ Conv2d(
792
+ 32,
793
+ 128,
794
+ (kernel_size, 1),
795
+ (stride, 1),
796
+ padding=(get_padding(kernel_size, 1), 0),
797
+ )
798
+ ),
799
+ norm_f(
800
+ Conv2d(
801
+ 128,
802
+ 512,
803
+ (kernel_size, 1),
804
+ (stride, 1),
805
+ padding=(get_padding(kernel_size, 1), 0),
806
+ )
807
+ ),
808
+ norm_f(
809
+ Conv2d(
810
+ 512,
811
+ 1024,
812
+ (kernel_size, 1),
813
+ (stride, 1),
814
+ padding=(get_padding(kernel_size, 1), 0),
815
+ )
816
+ ),
817
+ norm_f(
818
+ Conv2d(
819
+ 1024,
820
+ 1024,
821
+ (kernel_size, 1),
822
+ 1,
823
+ padding=(get_padding(kernel_size, 1), 0),
824
+ )
825
+ ),
826
+ ]
827
+ )
828
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
829
+
830
+ def forward(self, x):
831
+ fmap = []
832
+
833
+ # 1d to 2d
834
+ b, c, t = x.shape
835
+ if t % self.period != 0: # pad first
836
+ n_pad = self.period - (t % self.period)
837
+ x = F.pad(x, (0, n_pad), "reflect")
838
+ t = t + n_pad
839
+ x = x.view(b, c, t // self.period, self.period)
840
+
841
+ for l in self.convs:
842
+ x = l(x)
843
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
844
+ fmap.append(x)
845
+ x = self.conv_post(x)
846
+ fmap.append(x)
847
+ x = torch.flatten(x, 1, -1)
848
+
849
+ return x, fmap
infer_pack/modules.py CHANGED
@@ -1,522 +1,522 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import scipy
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
- from torch.nn.utils import weight_norm, remove_weight_norm
11
-
12
- from infer_pack import commons
13
- from infer_pack.commons import init_weights, get_padding
14
- from infer_pack.transforms import piecewise_rational_quadratic_transform
15
-
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
-
20
- class LayerNorm(nn.Module):
21
- def __init__(self, channels, eps=1e-5):
22
- super().__init__()
23
- self.channels = channels
24
- self.eps = eps
25
-
26
- self.gamma = nn.Parameter(torch.ones(channels))
27
- self.beta = nn.Parameter(torch.zeros(channels))
28
-
29
- def forward(self, x):
30
- x = x.transpose(1, -1)
31
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
- return x.transpose(1, -1)
33
-
34
-
35
- class ConvReluNorm(nn.Module):
36
- def __init__(
37
- self,
38
- in_channels,
39
- hidden_channels,
40
- out_channels,
41
- kernel_size,
42
- n_layers,
43
- p_dropout,
44
- ):
45
- super().__init__()
46
- self.in_channels = in_channels
47
- self.hidden_channels = hidden_channels
48
- self.out_channels = out_channels
49
- self.kernel_size = kernel_size
50
- self.n_layers = n_layers
51
- self.p_dropout = p_dropout
52
- assert n_layers > 1, "Number of layers should be larger than 0."
53
-
54
- self.conv_layers = nn.ModuleList()
55
- self.norm_layers = nn.ModuleList()
56
- self.conv_layers.append(
57
- nn.Conv1d(
58
- in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
- )
60
- )
61
- self.norm_layers.append(LayerNorm(hidden_channels))
62
- self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
- for _ in range(n_layers - 1):
64
- self.conv_layers.append(
65
- nn.Conv1d(
66
- hidden_channels,
67
- hidden_channels,
68
- kernel_size,
69
- padding=kernel_size // 2,
70
- )
71
- )
72
- self.norm_layers.append(LayerNorm(hidden_channels))
73
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
- self.proj.weight.data.zero_()
75
- self.proj.bias.data.zero_()
76
-
77
- def forward(self, x, x_mask):
78
- x_org = x
79
- for i in range(self.n_layers):
80
- x = self.conv_layers[i](x * x_mask)
81
- x = self.norm_layers[i](x)
82
- x = self.relu_drop(x)
83
- x = x_org + self.proj(x)
84
- return x * x_mask
85
-
86
-
87
- class DDSConv(nn.Module):
88
- """
89
- Dialted and Depth-Separable Convolution
90
- """
91
-
92
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
- super().__init__()
94
- self.channels = channels
95
- self.kernel_size = kernel_size
96
- self.n_layers = n_layers
97
- self.p_dropout = p_dropout
98
-
99
- self.drop = nn.Dropout(p_dropout)
100
- self.convs_sep = nn.ModuleList()
101
- self.convs_1x1 = nn.ModuleList()
102
- self.norms_1 = nn.ModuleList()
103
- self.norms_2 = nn.ModuleList()
104
- for i in range(n_layers):
105
- dilation = kernel_size**i
106
- padding = (kernel_size * dilation - dilation) // 2
107
- self.convs_sep.append(
108
- nn.Conv1d(
109
- channels,
110
- channels,
111
- kernel_size,
112
- groups=channels,
113
- dilation=dilation,
114
- padding=padding,
115
- )
116
- )
117
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
- self.norms_1.append(LayerNorm(channels))
119
- self.norms_2.append(LayerNorm(channels))
120
-
121
- def forward(self, x, x_mask, g=None):
122
- if g is not None:
123
- x = x + g
124
- for i in range(self.n_layers):
125
- y = self.convs_sep[i](x * x_mask)
126
- y = self.norms_1[i](y)
127
- y = F.gelu(y)
128
- y = self.convs_1x1[i](y)
129
- y = self.norms_2[i](y)
130
- y = F.gelu(y)
131
- y = self.drop(y)
132
- x = x + y
133
- return x * x_mask
134
-
135
-
136
- class WN(torch.nn.Module):
137
- def __init__(
138
- self,
139
- hidden_channels,
140
- kernel_size,
141
- dilation_rate,
142
- n_layers,
143
- gin_channels=0,
144
- p_dropout=0,
145
- ):
146
- super(WN, self).__init__()
147
- assert kernel_size % 2 == 1
148
- self.hidden_channels = hidden_channels
149
- self.kernel_size = (kernel_size,)
150
- self.dilation_rate = dilation_rate
151
- self.n_layers = n_layers
152
- self.gin_channels = gin_channels
153
- self.p_dropout = p_dropout
154
-
155
- self.in_layers = torch.nn.ModuleList()
156
- self.res_skip_layers = torch.nn.ModuleList()
157
- self.drop = nn.Dropout(p_dropout)
158
-
159
- if gin_channels != 0:
160
- cond_layer = torch.nn.Conv1d(
161
- gin_channels, 2 * hidden_channels * n_layers, 1
162
- )
163
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
-
165
- for i in range(n_layers):
166
- dilation = dilation_rate**i
167
- padding = int((kernel_size * dilation - dilation) / 2)
168
- in_layer = torch.nn.Conv1d(
169
- hidden_channels,
170
- 2 * hidden_channels,
171
- kernel_size,
172
- dilation=dilation,
173
- padding=padding,
174
- )
175
- in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
- self.in_layers.append(in_layer)
177
-
178
- # last one is not necessary
179
- if i < n_layers - 1:
180
- res_skip_channels = 2 * hidden_channels
181
- else:
182
- res_skip_channels = hidden_channels
183
-
184
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
- self.res_skip_layers.append(res_skip_layer)
187
-
188
- def forward(self, x, x_mask, g=None, **kwargs):
189
- output = torch.zeros_like(x)
190
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
-
192
- if g is not None:
193
- g = self.cond_layer(g)
194
-
195
- for i in range(self.n_layers):
196
- x_in = self.in_layers[i](x)
197
- if g is not None:
198
- cond_offset = i * 2 * self.hidden_channels
199
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
- else:
201
- g_l = torch.zeros_like(x_in)
202
-
203
- acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
- acts = self.drop(acts)
205
-
206
- res_skip_acts = self.res_skip_layers[i](acts)
207
- if i < self.n_layers - 1:
208
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
- x = (x + res_acts) * x_mask
210
- output = output + res_skip_acts[:, self.hidden_channels :, :]
211
- else:
212
- output = output + res_skip_acts
213
- return output * x_mask
214
-
215
- def remove_weight_norm(self):
216
- if self.gin_channels != 0:
217
- torch.nn.utils.remove_weight_norm(self.cond_layer)
218
- for l in self.in_layers:
219
- torch.nn.utils.remove_weight_norm(l)
220
- for l in self.res_skip_layers:
221
- torch.nn.utils.remove_weight_norm(l)
222
-
223
-
224
- class ResBlock1(torch.nn.Module):
225
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
- super(ResBlock1, self).__init__()
227
- self.convs1 = nn.ModuleList(
228
- [
229
- weight_norm(
230
- Conv1d(
231
- channels,
232
- channels,
233
- kernel_size,
234
- 1,
235
- dilation=dilation[0],
236
- padding=get_padding(kernel_size, dilation[0]),
237
- )
238
- ),
239
- weight_norm(
240
- Conv1d(
241
- channels,
242
- channels,
243
- kernel_size,
244
- 1,
245
- dilation=dilation[1],
246
- padding=get_padding(kernel_size, dilation[1]),
247
- )
248
- ),
249
- weight_norm(
250
- Conv1d(
251
- channels,
252
- channels,
253
- kernel_size,
254
- 1,
255
- dilation=dilation[2],
256
- padding=get_padding(kernel_size, dilation[2]),
257
- )
258
- ),
259
- ]
260
- )
261
- self.convs1.apply(init_weights)
262
-
263
- self.convs2 = nn.ModuleList(
264
- [
265
- weight_norm(
266
- Conv1d(
267
- channels,
268
- channels,
269
- kernel_size,
270
- 1,
271
- dilation=1,
272
- padding=get_padding(kernel_size, 1),
273
- )
274
- ),
275
- weight_norm(
276
- Conv1d(
277
- channels,
278
- channels,
279
- kernel_size,
280
- 1,
281
- dilation=1,
282
- padding=get_padding(kernel_size, 1),
283
- )
284
- ),
285
- weight_norm(
286
- Conv1d(
287
- channels,
288
- channels,
289
- kernel_size,
290
- 1,
291
- dilation=1,
292
- padding=get_padding(kernel_size, 1),
293
- )
294
- ),
295
- ]
296
- )
297
- self.convs2.apply(init_weights)
298
-
299
- def forward(self, x, x_mask=None):
300
- for c1, c2 in zip(self.convs1, self.convs2):
301
- xt = F.leaky_relu(x, LRELU_SLOPE)
302
- if x_mask is not None:
303
- xt = xt * x_mask
304
- xt = c1(xt)
305
- xt = F.leaky_relu(xt, LRELU_SLOPE)
306
- if x_mask is not None:
307
- xt = xt * x_mask
308
- xt = c2(xt)
309
- x = xt + x
310
- if x_mask is not None:
311
- x = x * x_mask
312
- return x
313
-
314
- def remove_weight_norm(self):
315
- for l in self.convs1:
316
- remove_weight_norm(l)
317
- for l in self.convs2:
318
- remove_weight_norm(l)
319
-
320
-
321
- class ResBlock2(torch.nn.Module):
322
- def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
- super(ResBlock2, self).__init__()
324
- self.convs = nn.ModuleList(
325
- [
326
- weight_norm(
327
- Conv1d(
328
- channels,
329
- channels,
330
- kernel_size,
331
- 1,
332
- dilation=dilation[0],
333
- padding=get_padding(kernel_size, dilation[0]),
334
- )
335
- ),
336
- weight_norm(
337
- Conv1d(
338
- channels,
339
- channels,
340
- kernel_size,
341
- 1,
342
- dilation=dilation[1],
343
- padding=get_padding(kernel_size, dilation[1]),
344
- )
345
- ),
346
- ]
347
- )
348
- self.convs.apply(init_weights)
349
-
350
- def forward(self, x, x_mask=None):
351
- for c in self.convs:
352
- xt = F.leaky_relu(x, LRELU_SLOPE)
353
- if x_mask is not None:
354
- xt = xt * x_mask
355
- xt = c(xt)
356
- x = xt + x
357
- if x_mask is not None:
358
- x = x * x_mask
359
- return x
360
-
361
- def remove_weight_norm(self):
362
- for l in self.convs:
363
- remove_weight_norm(l)
364
-
365
-
366
- class Log(nn.Module):
367
- def forward(self, x, x_mask, reverse=False, **kwargs):
368
- if not reverse:
369
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
- logdet = torch.sum(-y, [1, 2])
371
- return y, logdet
372
- else:
373
- x = torch.exp(x) * x_mask
374
- return x
375
-
376
-
377
- class Flip(nn.Module):
378
- def forward(self, x, *args, reverse=False, **kwargs):
379
- x = torch.flip(x, [1])
380
- if not reverse:
381
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
- return x, logdet
383
- else:
384
- return x
385
-
386
-
387
- class ElementwiseAffine(nn.Module):
388
- def __init__(self, channels):
389
- super().__init__()
390
- self.channels = channels
391
- self.m = nn.Parameter(torch.zeros(channels, 1))
392
- self.logs = nn.Parameter(torch.zeros(channels, 1))
393
-
394
- def forward(self, x, x_mask, reverse=False, **kwargs):
395
- if not reverse:
396
- y = self.m + torch.exp(self.logs) * x
397
- y = y * x_mask
398
- logdet = torch.sum(self.logs * x_mask, [1, 2])
399
- return y, logdet
400
- else:
401
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
- return x
403
-
404
-
405
- class ResidualCouplingLayer(nn.Module):
406
- def __init__(
407
- self,
408
- channels,
409
- hidden_channels,
410
- kernel_size,
411
- dilation_rate,
412
- n_layers,
413
- p_dropout=0,
414
- gin_channels=0,
415
- mean_only=False,
416
- ):
417
- assert channels % 2 == 0, "channels should be divisible by 2"
418
- super().__init__()
419
- self.channels = channels
420
- self.hidden_channels = hidden_channels
421
- self.kernel_size = kernel_size
422
- self.dilation_rate = dilation_rate
423
- self.n_layers = n_layers
424
- self.half_channels = channels // 2
425
- self.mean_only = mean_only
426
-
427
- self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
- self.enc = WN(
429
- hidden_channels,
430
- kernel_size,
431
- dilation_rate,
432
- n_layers,
433
- p_dropout=p_dropout,
434
- gin_channels=gin_channels,
435
- )
436
- self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
- self.post.weight.data.zero_()
438
- self.post.bias.data.zero_()
439
-
440
- def forward(self, x, x_mask, g=None, reverse=False):
441
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
- h = self.pre(x0) * x_mask
443
- h = self.enc(h, x_mask, g=g)
444
- stats = self.post(h) * x_mask
445
- if not self.mean_only:
446
- m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
- else:
448
- m = stats
449
- logs = torch.zeros_like(m)
450
-
451
- if not reverse:
452
- x1 = m + x1 * torch.exp(logs) * x_mask
453
- x = torch.cat([x0, x1], 1)
454
- logdet = torch.sum(logs, [1, 2])
455
- return x, logdet
456
- else:
457
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
- x = torch.cat([x0, x1], 1)
459
- return x
460
-
461
- def remove_weight_norm(self):
462
- self.enc.remove_weight_norm()
463
-
464
-
465
- class ConvFlow(nn.Module):
466
- def __init__(
467
- self,
468
- in_channels,
469
- filter_channels,
470
- kernel_size,
471
- n_layers,
472
- num_bins=10,
473
- tail_bound=5.0,
474
- ):
475
- super().__init__()
476
- self.in_channels = in_channels
477
- self.filter_channels = filter_channels
478
- self.kernel_size = kernel_size
479
- self.n_layers = n_layers
480
- self.num_bins = num_bins
481
- self.tail_bound = tail_bound
482
- self.half_channels = in_channels // 2
483
-
484
- self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
- self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
- self.proj = nn.Conv1d(
487
- filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
- )
489
- self.proj.weight.data.zero_()
490
- self.proj.bias.data.zero_()
491
-
492
- def forward(self, x, x_mask, g=None, reverse=False):
493
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
- h = self.pre(x0)
495
- h = self.convs(h, x_mask, g=g)
496
- h = self.proj(h) * x_mask
497
-
498
- b, c, t = x0.shape
499
- h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
-
501
- unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
- unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
- self.filter_channels
504
- )
505
- unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
-
507
- x1, logabsdet = piecewise_rational_quadratic_transform(
508
- x1,
509
- unnormalized_widths,
510
- unnormalized_heights,
511
- unnormalized_derivatives,
512
- inverse=reverse,
513
- tails="linear",
514
- tail_bound=self.tail_bound,
515
- )
516
-
517
- x = torch.cat([x0, x1], 1) * x_mask
518
- logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
- if not reverse:
520
- return x, logdet
521
- else:
522
- return x
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ from infer_pack import commons
13
+ from infer_pack.commons import init_weights, get_padding
14
+ from infer_pack.transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 0."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dialted and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ ):
146
+ super(WN, self).__init__()
147
+ assert kernel_size % 2 == 1
148
+ self.hidden_channels = hidden_channels
149
+ self.kernel_size = (kernel_size,)
150
+ self.dilation_rate = dilation_rate
151
+ self.n_layers = n_layers
152
+ self.gin_channels = gin_channels
153
+ self.p_dropout = p_dropout
154
+
155
+ self.in_layers = torch.nn.ModuleList()
156
+ self.res_skip_layers = torch.nn.ModuleList()
157
+ self.drop = nn.Dropout(p_dropout)
158
+
159
+ if gin_channels != 0:
160
+ cond_layer = torch.nn.Conv1d(
161
+ gin_channels, 2 * hidden_channels * n_layers, 1
162
+ )
163
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
+
165
+ for i in range(n_layers):
166
+ dilation = dilation_rate**i
167
+ padding = int((kernel_size * dilation - dilation) / 2)
168
+ in_layer = torch.nn.Conv1d(
169
+ hidden_channels,
170
+ 2 * hidden_channels,
171
+ kernel_size,
172
+ dilation=dilation,
173
+ padding=padding,
174
+ )
175
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
+ self.in_layers.append(in_layer)
177
+
178
+ # last one is not necessary
179
+ if i < n_layers - 1:
180
+ res_skip_channels = 2 * hidden_channels
181
+ else:
182
+ res_skip_channels = hidden_channels
183
+
184
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
+ self.res_skip_layers.append(res_skip_layer)
187
+
188
+ def forward(self, x, x_mask, g=None, **kwargs):
189
+ output = torch.zeros_like(x)
190
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
+
192
+ if g is not None:
193
+ g = self.cond_layer(g)
194
+
195
+ for i in range(self.n_layers):
196
+ x_in = self.in_layers[i](x)
197
+ if g is not None:
198
+ cond_offset = i * 2 * self.hidden_channels
199
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
+ else:
201
+ g_l = torch.zeros_like(x_in)
202
+
203
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
+ acts = self.drop(acts)
205
+
206
+ res_skip_acts = self.res_skip_layers[i](acts)
207
+ if i < self.n_layers - 1:
208
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
+ x = (x + res_acts) * x_mask
210
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
211
+ else:
212
+ output = output + res_skip_acts
213
+ return output * x_mask
214
+
215
+ def remove_weight_norm(self):
216
+ if self.gin_channels != 0:
217
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
218
+ for l in self.in_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+ for l in self.res_skip_layers:
221
+ torch.nn.utils.remove_weight_norm(l)
222
+
223
+
224
+ class ResBlock1(torch.nn.Module):
225
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
+ super(ResBlock1, self).__init__()
227
+ self.convs1 = nn.ModuleList(
228
+ [
229
+ weight_norm(
230
+ Conv1d(
231
+ channels,
232
+ channels,
233
+ kernel_size,
234
+ 1,
235
+ dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]),
237
+ )
238
+ ),
239
+ weight_norm(
240
+ Conv1d(
241
+ channels,
242
+ channels,
243
+ kernel_size,
244
+ 1,
245
+ dilation=dilation[1],
246
+ padding=get_padding(kernel_size, dilation[1]),
247
+ )
248
+ ),
249
+ weight_norm(
250
+ Conv1d(
251
+ channels,
252
+ channels,
253
+ kernel_size,
254
+ 1,
255
+ dilation=dilation[2],
256
+ padding=get_padding(kernel_size, dilation[2]),
257
+ )
258
+ ),
259
+ ]
260
+ )
261
+ self.convs1.apply(init_weights)
262
+
263
+ self.convs2 = nn.ModuleList(
264
+ [
265
+ weight_norm(
266
+ Conv1d(
267
+ channels,
268
+ channels,
269
+ kernel_size,
270
+ 1,
271
+ dilation=1,
272
+ padding=get_padding(kernel_size, 1),
273
+ )
274
+ ),
275
+ weight_norm(
276
+ Conv1d(
277
+ channels,
278
+ channels,
279
+ kernel_size,
280
+ 1,
281
+ dilation=1,
282
+ padding=get_padding(kernel_size, 1),
283
+ )
284
+ ),
285
+ weight_norm(
286
+ Conv1d(
287
+ channels,
288
+ channels,
289
+ kernel_size,
290
+ 1,
291
+ dilation=1,
292
+ padding=get_padding(kernel_size, 1),
293
+ )
294
+ ),
295
+ ]
296
+ )
297
+ self.convs2.apply(init_weights)
298
+
299
+ def forward(self, x, x_mask=None):
300
+ for c1, c2 in zip(self.convs1, self.convs2):
301
+ xt = F.leaky_relu(x, LRELU_SLOPE)
302
+ if x_mask is not None:
303
+ xt = xt * x_mask
304
+ xt = c1(xt)
305
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
306
+ if x_mask is not None:
307
+ xt = xt * x_mask
308
+ xt = c2(xt)
309
+ x = xt + x
310
+ if x_mask is not None:
311
+ x = x * x_mask
312
+ return x
313
+
314
+ def remove_weight_norm(self):
315
+ for l in self.convs1:
316
+ remove_weight_norm(l)
317
+ for l in self.convs2:
318
+ remove_weight_norm(l)
319
+
320
+
321
+ class ResBlock2(torch.nn.Module):
322
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
+ super(ResBlock2, self).__init__()
324
+ self.convs = nn.ModuleList(
325
+ [
326
+ weight_norm(
327
+ Conv1d(
328
+ channels,
329
+ channels,
330
+ kernel_size,
331
+ 1,
332
+ dilation=dilation[0],
333
+ padding=get_padding(kernel_size, dilation[0]),
334
+ )
335
+ ),
336
+ weight_norm(
337
+ Conv1d(
338
+ channels,
339
+ channels,
340
+ kernel_size,
341
+ 1,
342
+ dilation=dilation[1],
343
+ padding=get_padding(kernel_size, dilation[1]),
344
+ )
345
+ ),
346
+ ]
347
+ )
348
+ self.convs.apply(init_weights)
349
+
350
+ def forward(self, x, x_mask=None):
351
+ for c in self.convs:
352
+ xt = F.leaky_relu(x, LRELU_SLOPE)
353
+ if x_mask is not None:
354
+ xt = xt * x_mask
355
+ xt = c(xt)
356
+ x = xt + x
357
+ if x_mask is not None:
358
+ x = x * x_mask
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ for l in self.convs:
363
+ remove_weight_norm(l)
364
+
365
+
366
+ class Log(nn.Module):
367
+ def forward(self, x, x_mask, reverse=False, **kwargs):
368
+ if not reverse:
369
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
+ logdet = torch.sum(-y, [1, 2])
371
+ return y, logdet
372
+ else:
373
+ x = torch.exp(x) * x_mask
374
+ return x
375
+
376
+
377
+ class Flip(nn.Module):
378
+ def forward(self, x, *args, reverse=False, **kwargs):
379
+ x = torch.flip(x, [1])
380
+ if not reverse:
381
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
+ return x, logdet
383
+ else:
384
+ return x
385
+
386
+
387
+ class ElementwiseAffine(nn.Module):
388
+ def __init__(self, channels):
389
+ super().__init__()
390
+ self.channels = channels
391
+ self.m = nn.Parameter(torch.zeros(channels, 1))
392
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
393
+
394
+ def forward(self, x, x_mask, reverse=False, **kwargs):
395
+ if not reverse:
396
+ y = self.m + torch.exp(self.logs) * x
397
+ y = y * x_mask
398
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
399
+ return y, logdet
400
+ else:
401
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
+ return x
403
+
404
+
405
+ class ResidualCouplingLayer(nn.Module):
406
+ def __init__(
407
+ self,
408
+ channels,
409
+ hidden_channels,
410
+ kernel_size,
411
+ dilation_rate,
412
+ n_layers,
413
+ p_dropout=0,
414
+ gin_channels=0,
415
+ mean_only=False,
416
+ ):
417
+ assert channels % 2 == 0, "channels should be divisible by 2"
418
+ super().__init__()
419
+ self.channels = channels
420
+ self.hidden_channels = hidden_channels
421
+ self.kernel_size = kernel_size
422
+ self.dilation_rate = dilation_rate
423
+ self.n_layers = n_layers
424
+ self.half_channels = channels // 2
425
+ self.mean_only = mean_only
426
+
427
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
+ self.enc = WN(
429
+ hidden_channels,
430
+ kernel_size,
431
+ dilation_rate,
432
+ n_layers,
433
+ p_dropout=p_dropout,
434
+ gin_channels=gin_channels,
435
+ )
436
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
+ self.post.weight.data.zero_()
438
+ self.post.bias.data.zero_()
439
+
440
+ def forward(self, x, x_mask, g=None, reverse=False):
441
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
+ h = self.pre(x0) * x_mask
443
+ h = self.enc(h, x_mask, g=g)
444
+ stats = self.post(h) * x_mask
445
+ if not self.mean_only:
446
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
+ else:
448
+ m = stats
449
+ logs = torch.zeros_like(m)
450
+
451
+ if not reverse:
452
+ x1 = m + x1 * torch.exp(logs) * x_mask
453
+ x = torch.cat([x0, x1], 1)
454
+ logdet = torch.sum(logs, [1, 2])
455
+ return x, logdet
456
+ else:
457
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
+ x = torch.cat([x0, x1], 1)
459
+ return x
460
+
461
+ def remove_weight_norm(self):
462
+ self.enc.remove_weight_norm()
463
+
464
+
465
+ class ConvFlow(nn.Module):
466
+ def __init__(
467
+ self,
468
+ in_channels,
469
+ filter_channels,
470
+ kernel_size,
471
+ n_layers,
472
+ num_bins=10,
473
+ tail_bound=5.0,
474
+ ):
475
+ super().__init__()
476
+ self.in_channels = in_channels
477
+ self.filter_channels = filter_channels
478
+ self.kernel_size = kernel_size
479
+ self.n_layers = n_layers
480
+ self.num_bins = num_bins
481
+ self.tail_bound = tail_bound
482
+ self.half_channels = in_channels // 2
483
+
484
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
+ self.proj = nn.Conv1d(
487
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
+ )
489
+ self.proj.weight.data.zero_()
490
+ self.proj.bias.data.zero_()
491
+
492
+ def forward(self, x, x_mask, g=None, reverse=False):
493
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
+ h = self.pre(x0)
495
+ h = self.convs(h, x_mask, g=g)
496
+ h = self.proj(h) * x_mask
497
+
498
+ b, c, t = x0.shape
499
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
+
501
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
+ self.filter_channels
504
+ )
505
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
+
507
+ x1, logabsdet = piecewise_rational_quadratic_transform(
508
+ x1,
509
+ unnormalized_widths,
510
+ unnormalized_heights,
511
+ unnormalized_derivatives,
512
+ inverse=reverse,
513
+ tails="linear",
514
+ tail_bound=self.tail_bound,
515
+ )
516
+
517
+ x = torch.cat([x0, x1], 1) * x_mask
518
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
+ if not reverse:
520
+ return x, logdet
521
+ else:
522
+ return x
infer_pack/transforms.py CHANGED
@@ -1,193 +1,209 @@
1
- import torch
2
- from torch.nn import functional as F
3
-
4
- import numpy as np
5
-
6
-
7
- DEFAULT_MIN_BIN_WIDTH = 1e-3
8
- DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
- DEFAULT_MIN_DERIVATIVE = 1e-3
10
-
11
-
12
- def piecewise_rational_quadratic_transform(inputs,
13
- unnormalized_widths,
14
- unnormalized_heights,
15
- unnormalized_derivatives,
16
- inverse=False,
17
- tails=None,
18
- tail_bound=1.,
19
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21
- min_derivative=DEFAULT_MIN_DERIVATIVE):
22
-
23
- if tails is None:
24
- spline_fn = rational_quadratic_spline
25
- spline_kwargs = {}
26
- else:
27
- spline_fn = unconstrained_rational_quadratic_spline
28
- spline_kwargs = {
29
- 'tails': tails,
30
- 'tail_bound': tail_bound
31
- }
32
-
33
- outputs, logabsdet = spline_fn(
34
- inputs=inputs,
35
- unnormalized_widths=unnormalized_widths,
36
- unnormalized_heights=unnormalized_heights,
37
- unnormalized_derivatives=unnormalized_derivatives,
38
- inverse=inverse,
39
- min_bin_width=min_bin_width,
40
- min_bin_height=min_bin_height,
41
- min_derivative=min_derivative,
42
- **spline_kwargs
43
- )
44
- return outputs, logabsdet
45
-
46
-
47
- def searchsorted(bin_locations, inputs, eps=1e-6):
48
- bin_locations[..., -1] += eps
49
- return torch.sum(
50
- inputs[..., None] >= bin_locations,
51
- dim=-1
52
- ) - 1
53
-
54
-
55
- def unconstrained_rational_quadratic_spline(inputs,
56
- unnormalized_widths,
57
- unnormalized_heights,
58
- unnormalized_derivatives,
59
- inverse=False,
60
- tails='linear',
61
- tail_bound=1.,
62
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64
- min_derivative=DEFAULT_MIN_DERIVATIVE):
65
- inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66
- outside_interval_mask = ~inside_interval_mask
67
-
68
- outputs = torch.zeros_like(inputs)
69
- logabsdet = torch.zeros_like(inputs)
70
-
71
- if tails == 'linear':
72
- unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73
- constant = np.log(np.exp(1 - min_derivative) - 1)
74
- unnormalized_derivatives[..., 0] = constant
75
- unnormalized_derivatives[..., -1] = constant
76
-
77
- outputs[outside_interval_mask] = inputs[outside_interval_mask]
78
- logabsdet[outside_interval_mask] = 0
79
- else:
80
- raise RuntimeError('{} tails are not implemented.'.format(tails))
81
-
82
- outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83
- inputs=inputs[inside_interval_mask],
84
- unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
- unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
- unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
- inverse=inverse,
88
- left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89
- min_bin_width=min_bin_width,
90
- min_bin_height=min_bin_height,
91
- min_derivative=min_derivative
92
- )
93
-
94
- return outputs, logabsdet
95
-
96
- def rational_quadratic_spline(inputs,
97
- unnormalized_widths,
98
- unnormalized_heights,
99
- unnormalized_derivatives,
100
- inverse=False,
101
- left=0., right=1., bottom=0., top=1.,
102
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104
- min_derivative=DEFAULT_MIN_DERIVATIVE):
105
- if torch.min(inputs) < left or torch.max(inputs) > right:
106
- raise ValueError('Input to a transform is not within its domain')
107
-
108
- num_bins = unnormalized_widths.shape[-1]
109
-
110
- if min_bin_width * num_bins > 1.0:
111
- raise ValueError('Minimal bin width too large for the number of bins')
112
- if min_bin_height * num_bins > 1.0:
113
- raise ValueError('Minimal bin height too large for the number of bins')
114
-
115
- widths = F.softmax(unnormalized_widths, dim=-1)
116
- widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117
- cumwidths = torch.cumsum(widths, dim=-1)
118
- cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119
- cumwidths = (right - left) * cumwidths + left
120
- cumwidths[..., 0] = left
121
- cumwidths[..., -1] = right
122
- widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123
-
124
- derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125
-
126
- heights = F.softmax(unnormalized_heights, dim=-1)
127
- heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128
- cumheights = torch.cumsum(heights, dim=-1)
129
- cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130
- cumheights = (top - bottom) * cumheights + bottom
131
- cumheights[..., 0] = bottom
132
- cumheights[..., -1] = top
133
- heights = cumheights[..., 1:] - cumheights[..., :-1]
134
-
135
- if inverse:
136
- bin_idx = searchsorted(cumheights, inputs)[..., None]
137
- else:
138
- bin_idx = searchsorted(cumwidths, inputs)[..., None]
139
-
140
- input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141
- input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142
-
143
- input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144
- delta = heights / widths
145
- input_delta = delta.gather(-1, bin_idx)[..., 0]
146
-
147
- input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148
- input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149
-
150
- input_heights = heights.gather(-1, bin_idx)[..., 0]
151
-
152
- if inverse:
153
- a = (((inputs - input_cumheights) * (input_derivatives
154
- + input_derivatives_plus_one
155
- - 2 * input_delta)
156
- + input_heights * (input_delta - input_derivatives)))
157
- b = (input_heights * input_derivatives
158
- - (inputs - input_cumheights) * (input_derivatives
159
- + input_derivatives_plus_one
160
- - 2 * input_delta))
161
- c = - input_delta * (inputs - input_cumheights)
162
-
163
- discriminant = b.pow(2) - 4 * a * c
164
- assert (discriminant >= 0).all()
165
-
166
- root = (2 * c) / (-b - torch.sqrt(discriminant))
167
- outputs = root * input_bin_widths + input_cumwidths
168
-
169
- theta_one_minus_theta = root * (1 - root)
170
- denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171
- * theta_one_minus_theta)
172
- derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173
- + 2 * input_delta * theta_one_minus_theta
174
- + input_derivatives * (1 - root).pow(2))
175
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176
-
177
- return outputs, -logabsdet
178
- else:
179
- theta = (inputs - input_cumwidths) / input_bin_widths
180
- theta_one_minus_theta = theta * (1 - theta)
181
-
182
- numerator = input_heights * (input_delta * theta.pow(2)
183
- + input_derivatives * theta_one_minus_theta)
184
- denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185
- * theta_one_minus_theta)
186
- outputs = input_cumheights + numerator / denominator
187
-
188
- derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189
- + 2 * input_delta * theta_one_minus_theta
190
- + input_derivatives * (1 - theta).pow(2))
191
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192
-
193
- return outputs, logabsdet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
requirements.txt CHANGED
@@ -4,8 +4,7 @@ torchaudio
4
  fairseq==0.12.2
5
  scipy==1.9.3
6
  pyworld>=0.3.2
7
- faiss-cpu==1.7.2 ; python_version < "3.11"
8
- faiss-cpu==1.7.3 ; python_version > "3.10"
9
  praat-parselmouth>=0.4.3
10
  librosa==0.9.2
11
  edge-tts
 
4
  fairseq==0.12.2
5
  scipy==1.9.3
6
  pyworld>=0.3.2
7
+ faiss-cpu==1.7.3
 
8
  praat-parselmouth>=0.4.3
9
  librosa==0.9.2
10
  edge-tts
util.py CHANGED
@@ -25,16 +25,22 @@ def has_mps() -> bool:
25
  return False
26
 
27
 
28
- # https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/main/config.py#L58-L71 # noqa
29
  def is_half(device: str) -> bool:
30
- if device == 'cpu':
31
  return False
32
  else:
33
- if has_mps():
34
- return True
35
-
36
- gpu_name = torch.cuda.get_device_name(int(device.split(':')[-1]))
37
- if '16' in gpu_name or 'MX' in gpu_name:
 
 
 
 
 
 
 
38
  return False
39
 
40
  return True
 
25
  return False
26
 
27
 
 
28
  def is_half(device: str) -> bool:
29
+ if not device.startswith('cuda'):
30
  return False
31
  else:
32
+ gpu_name = torch.cuda.get_device_name(
33
+ int(device.split(':')[-1])
34
+ ).upper()
35
+
36
+ # ...regex?
37
+ if (
38
+ ('16' in gpu_name and 'V100' not in gpu_name)
39
+ or 'P40' in gpu_name
40
+ or '1060' in gpu_name
41
+ or '1070' in gpu_name
42
+ or '1080' in gpu_name
43
+ ):
44
  return False
45
 
46
  return True