pdjdev commited on
Commit
85a7d2c
·
1 Parent(s): 798b4d1

add ddsp-svc

Browse files
Files changed (48) hide show
  1. DDSP-SVC/.gitignore +10 -0
  2. DDSP-SVC/LICENSE +21 -0
  3. DDSP-SVC/data/train/.gitignore +3 -0
  4. DDSP-SVC/data/train/audio/.gitignore +2 -0
  5. DDSP-SVC/data/val/.gitignore +3 -0
  6. DDSP-SVC/data/val/audio/.gitignore +2 -0
  7. DDSP-SVC/data_loaders.py +244 -0
  8. DDSP-SVC/ddsp/__init__.py +0 -0
  9. DDSP-SVC/ddsp/core.py +281 -0
  10. DDSP-SVC/ddsp/loss.py +57 -0
  11. DDSP-SVC/ddsp/pcmer.py +380 -0
  12. DDSP-SVC/ddsp/unit2control.py +86 -0
  13. DDSP-SVC/ddsp/vocoder.py +652 -0
  14. DDSP-SVC/diffusion/data_loaders.py +271 -0
  15. DDSP-SVC/diffusion/diffusion.py +317 -0
  16. DDSP-SVC/diffusion/dpm_solver_pytorch.py +1201 -0
  17. DDSP-SVC/diffusion/infer_gt_mel.py +78 -0
  18. DDSP-SVC/diffusion/solver.py +171 -0
  19. DDSP-SVC/diffusion/unit2mel.py +96 -0
  20. DDSP-SVC/diffusion/vocoder.py +87 -0
  21. DDSP-SVC/diffusion/wavenet.py +108 -0
  22. DDSP-SVC/draw.py +102 -0
  23. DDSP-SVC/encoder/hubert/model.py +293 -0
  24. DDSP-SVC/enhancer.py +115 -0
  25. DDSP-SVC/exp/.gitignore +2 -0
  26. DDSP-SVC/flask_api.py +178 -0
  27. DDSP-SVC/gui.py +483 -0
  28. DDSP-SVC/gui_diff.py +576 -0
  29. DDSP-SVC/gui_diff_locale.py +154 -0
  30. DDSP-SVC/gui_locale.py +130 -0
  31. DDSP-SVC/logger/__init__.py +0 -0
  32. DDSP-SVC/logger/saver.py +145 -0
  33. DDSP-SVC/logger/utils.py +122 -0
  34. DDSP-SVC/main.py +282 -0
  35. DDSP-SVC/main_diff.py +372 -0
  36. DDSP-SVC/nsf_hifigan/env.py +15 -0
  37. DDSP-SVC/nsf_hifigan/models.py +430 -0
  38. DDSP-SVC/nsf_hifigan/nvSTFT.py +129 -0
  39. DDSP-SVC/nsf_hifigan/utils.py +68 -0
  40. DDSP-SVC/preprocess.py +197 -0
  41. DDSP-SVC/pretrain/hubert/.gitignore +2 -0
  42. DDSP-SVC/pretrain/nsf_hifigan/.gitignore +2 -0
  43. DDSP-SVC/requirements.txt +25 -0
  44. DDSP-SVC/slicer.py +146 -0
  45. DDSP-SVC/solver.py +151 -0
  46. DDSP-SVC/train.py +93 -0
  47. DDSP-SVC/train_diff.py +70 -0
  48. DDSP-SVC/webui.py +267 -0
DDSP-SVC/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+
3
+ venv/
4
+ results/
5
+ configs/
6
+ !configs/combsub.yaml
7
+ !configs/combsub-old.yaml
8
+ !configs/sins.yaml
9
+ !configs/diffusion.yaml
10
+ cache/
DDSP-SVC/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 yxlllc
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
DDSP-SVC/data/train/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *
2
+ !.gitignore
3
+ !audio
DDSP-SVC/data/train/audio/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
DDSP-SVC/data/val/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *
2
+ !.gitignore
3
+ !audio
DDSP-SVC/data/val/audio/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
DDSP-SVC/data_loaders.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import librosa
5
+ import torch
6
+ import random
7
+ from tqdm import tqdm
8
+ from torch.utils.data import Dataset
9
+
10
+ def traverse_dir(
11
+ root_dir,
12
+ extension,
13
+ amount=None,
14
+ str_include=None,
15
+ str_exclude=None,
16
+ is_pure=False,
17
+ is_sort=False,
18
+ is_ext=True):
19
+
20
+ file_list = []
21
+ cnt = 0
22
+ for root, _, files in os.walk(root_dir):
23
+ for file in files:
24
+ if file.endswith(extension):
25
+ # path
26
+ mix_path = os.path.join(root, file)
27
+ pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
28
+
29
+ # amount
30
+ if (amount is not None) and (cnt == amount):
31
+ if is_sort:
32
+ file_list.sort()
33
+ return file_list
34
+
35
+ # check string
36
+ if (str_include is not None) and (str_include not in pure_path):
37
+ continue
38
+ if (str_exclude is not None) and (str_exclude in pure_path):
39
+ continue
40
+
41
+ if not is_ext:
42
+ ext = pure_path.split('.')[-1]
43
+ pure_path = pure_path[:-(len(ext)+1)]
44
+ file_list.append(pure_path)
45
+ cnt += 1
46
+ if is_sort:
47
+ file_list.sort()
48
+ return file_list
49
+
50
+
51
+ def get_data_loaders(args, whole_audio=False):
52
+ data_train = AudioDataset(
53
+ args.data.train_path,
54
+ waveform_sec=args.data.duration,
55
+ hop_size=args.data.block_size,
56
+ sample_rate=args.data.sampling_rate,
57
+ load_all_data=args.train.cache_all_data,
58
+ whole_audio=whole_audio,
59
+ n_spk=args.model.n_spk,
60
+ device=args.train.cache_device,
61
+ fp16=args.train.cache_fp16,
62
+ use_aug=True)
63
+ loader_train = torch.utils.data.DataLoader(
64
+ data_train ,
65
+ batch_size=args.train.batch_size if not whole_audio else 1,
66
+ shuffle=True,
67
+ num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0,
68
+ persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False,
69
+ pin_memory=True if args.train.cache_device=='cpu' else False
70
+ )
71
+ data_valid = AudioDataset(
72
+ args.data.valid_path,
73
+ waveform_sec=args.data.duration,
74
+ hop_size=args.data.block_size,
75
+ sample_rate=args.data.sampling_rate,
76
+ load_all_data=args.train.cache_all_data,
77
+ whole_audio=True,
78
+ n_spk=args.model.n_spk)
79
+ loader_valid = torch.utils.data.DataLoader(
80
+ data_valid,
81
+ batch_size=1,
82
+ shuffle=False,
83
+ num_workers=0,
84
+ pin_memory=True
85
+ )
86
+ return loader_train, loader_valid
87
+
88
+
89
+ class AudioDataset(Dataset):
90
+ def __init__(
91
+ self,
92
+ path_root,
93
+ waveform_sec,
94
+ hop_size,
95
+ sample_rate,
96
+ load_all_data=True,
97
+ whole_audio=False,
98
+ n_spk=1,
99
+ device = 'cpu',
100
+ fp16 = False,
101
+ use_aug = False
102
+ ):
103
+ super().__init__()
104
+
105
+ self.waveform_sec = waveform_sec
106
+ self.sample_rate = sample_rate
107
+ self.hop_size = hop_size
108
+ self.path_root = path_root
109
+ self.paths = traverse_dir(
110
+ os.path.join(path_root, 'audio'),
111
+ extension='wav',
112
+ is_pure=True,
113
+ is_sort=True,
114
+ is_ext=False
115
+ )
116
+ self.whole_audio = whole_audio
117
+ self.use_aug = use_aug
118
+ self.data_buffer={}
119
+ if load_all_data:
120
+ print('Load all the data from :', path_root)
121
+ else:
122
+ print('Load the f0, volume data from :', path_root)
123
+ for name in tqdm(self.paths, total=len(self.paths)):
124
+ path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
125
+ duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)
126
+
127
+ path_f0 = os.path.join(self.path_root, 'f0', name) + '.npy'
128
+ f0 = np.load(path_f0)
129
+ f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device)
130
+
131
+ path_volume = os.path.join(self.path_root, 'volume', name) + '.npy'
132
+ volume = np.load(path_volume)
133
+ volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)
134
+
135
+ if n_spk is not None and n_spk > 1:
136
+ spk_id = int(os.path.dirname(name)) if str.isdigit(os.path.dirname(name)) else 0
137
+ if spk_id < 1 or spk_id > n_spk:
138
+ raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ')
139
+ else:
140
+ spk_id = 1
141
+ spk_id = torch.LongTensor(np.array([spk_id])).to(device)
142
+
143
+ if load_all_data:
144
+ audio, sr = librosa.load(path_audio, sr=self.sample_rate)
145
+ if len(audio.shape) > 1:
146
+ audio = librosa.to_mono(audio)
147
+ audio = torch.from_numpy(audio).to(device)
148
+
149
+ path_units = os.path.join(self.path_root, 'units', name) + '.npy'
150
+ units = np.load(path_units)
151
+ units = torch.from_numpy(units).to(device)
152
+
153
+ if fp16:
154
+ audio = audio.half()
155
+ units = units.half()
156
+
157
+ self.data_buffer[name] = {
158
+ 'duration': duration,
159
+ 'audio': audio,
160
+ 'units': units,
161
+ 'f0': f0,
162
+ 'volume': volume,
163
+ 'spk_id': spk_id
164
+ }
165
+ else:
166
+ self.data_buffer[name] = {
167
+ 'duration': duration,
168
+ 'f0': f0,
169
+ 'volume': volume,
170
+ 'spk_id': spk_id
171
+ }
172
+
173
+
174
+ def __getitem__(self, file_idx):
175
+ name = self.paths[file_idx]
176
+ data_buffer = self.data_buffer[name]
177
+ # check duration. if too short, then skip
178
+ if data_buffer['duration'] < (self.waveform_sec + 0.1):
179
+ return self.__getitem__( (file_idx + 1) % len(self.paths))
180
+
181
+ # get item
182
+ return self.get_data(name, data_buffer)
183
+
184
+ def get_data(self, name, data_buffer):
185
+ frame_resolution = self.hop_size / self.sample_rate
186
+ duration = data_buffer['duration']
187
+ waveform_sec = duration if self.whole_audio else self.waveform_sec
188
+
189
+ # load audio
190
+ idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
191
+ start_frame = int(idx_from / frame_resolution)
192
+ units_frame_len = int(waveform_sec / frame_resolution)
193
+ audio = data_buffer.get('audio')
194
+ if audio is None:
195
+ path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
196
+ audio, sr = librosa.load(
197
+ path_audio,
198
+ sr = self.sample_rate,
199
+ offset = start_frame * frame_resolution,
200
+ duration = waveform_sec)
201
+ if len(audio.shape) > 1:
202
+ audio = librosa.to_mono(audio)
203
+ # clip audio into N seconds
204
+ audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
205
+ audio = torch.from_numpy(audio).float()
206
+ else:
207
+ audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
208
+
209
+ # load units
210
+ units = data_buffer.get('units')
211
+ if units is None:
212
+ units = os.path.join(self.path_root, 'units', name) + '.npy'
213
+ units = np.load(units)
214
+ units = units[start_frame : start_frame + units_frame_len]
215
+ units = torch.from_numpy(units).float()
216
+ else:
217
+ units = units[start_frame : start_frame + units_frame_len]
218
+
219
+ # load f0
220
+ f0 = data_buffer.get('f0')
221
+ f0_frames = f0[start_frame : start_frame + units_frame_len]
222
+
223
+ # load volume
224
+ volume = data_buffer.get('volume')
225
+ volume_frames = volume[start_frame : start_frame + units_frame_len]
226
+
227
+ # load spk_id
228
+ spk_id = data_buffer.get('spk_id')
229
+
230
+ # volume augmentation
231
+ if self.use_aug:
232
+ max_amp = float(torch.max(torch.abs(audio))) + 1e-5
233
+ max_shift = min(1, np.log10(1/max_amp))
234
+ log10_vol_shift = random.uniform(-1, max_shift)
235
+ audio_aug = audio * (10 ** log10_vol_shift)
236
+ volume_frames_aug = volume_frames * (10 ** log10_vol_shift)
237
+ else:
238
+ audio_aug = audio
239
+ volume_frames_aug = volume_frames
240
+
241
+ return dict(audio=audio_aug, f0=f0_frames, volume=volume_frames_aug, units=units, spk_id=spk_id, name=name)
242
+
243
+ def __len__(self):
244
+ return len(self.paths)
DDSP-SVC/ddsp/__init__.py ADDED
File without changes
DDSP-SVC/ddsp/core.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ import math
6
+ import numpy as np
7
+
8
+ def MaskedAvgPool1d(x, kernel_size):
9
+ x = x.unsqueeze(1)
10
+ x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
11
+ mask = ~torch.isnan(x)
12
+ masked_x = torch.where(mask, x, torch.zeros_like(x))
13
+ ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device)
14
+
15
+ # Perform sum pooling
16
+ sum_pooled = F.conv1d(
17
+ masked_x,
18
+ ones_kernel,
19
+ stride=1,
20
+ padding=0,
21
+ groups=x.size(1),
22
+ )
23
+
24
+ # Count the non-masked (valid) elements in each pooling window
25
+ valid_count = F.conv1d(
26
+ mask.float(),
27
+ ones_kernel,
28
+ stride=1,
29
+ padding=0,
30
+ groups=x.size(1),
31
+ )
32
+ valid_count = valid_count.clamp(min=1) # Avoid division by zero
33
+
34
+ # Perform masked average pooling
35
+ avg_pooled = sum_pooled / valid_count
36
+
37
+ return avg_pooled.squeeze(1)
38
+
39
+ def MedianPool1d(x, kernel_size):
40
+ x = x.unsqueeze(1)
41
+ x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
42
+ x = x.squeeze(1)
43
+ x = x.unfold(1, kernel_size, 1)
44
+ x, _ = torch.sort(x, dim=-1)
45
+ return x[:, :, (kernel_size - 1) // 2]
46
+
47
+ def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
48
+ """Calculate final size for efficient FFT.
49
+ Args:
50
+ frame_size: Size of the audio frame.
51
+ ir_size: Size of the convolving impulse response.
52
+ power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth
53
+ numbers. TPU requires power of 2, while GPU is more flexible.
54
+ Returns:
55
+ fft_size: Size for efficient FFT.
56
+ """
57
+ convolved_frame_size = ir_size + frame_size - 1
58
+ if power_of_2:
59
+ # Next power of 2.
60
+ fft_size = int(2**np.ceil(np.log2(convolved_frame_size)))
61
+ else:
62
+ fft_size = convolved_frame_size
63
+ return fft_size
64
+
65
+
66
+ def upsample(signal, factor):
67
+ signal = signal.permute(0, 2, 1)
68
+ signal = nn.functional.interpolate(torch.cat((signal,signal[:,:,-1:]),2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True)
69
+ signal = signal[:,:,:-1]
70
+ return signal.permute(0, 2, 1)
71
+
72
+
73
+ def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
74
+ n_harm = amplitudes.shape[-1]
75
+ pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
76
+ aa = (pitches < fmax).float() + 1e-7
77
+ return amplitudes * aa
78
+
79
+
80
+ def crop_and_compensate_delay(audio, audio_size, ir_size,
81
+ padding = 'same',
82
+ delay_compensation = -1):
83
+ """Crop audio output from convolution to compensate for group delay.
84
+ Args:
85
+ audio: Audio after convolution. Tensor of shape [batch, time_steps].
86
+ audio_size: Initial size of the audio before convolution.
87
+ ir_size: Size of the convolving impulse response.
88
+ padding: Either 'valid' or 'same'. For 'same' the final output to be the
89
+ same size as the input audio (audio_timesteps). For 'valid' the audio is
90
+ extended to include the tail of the impulse response (audio_timesteps +
91
+ ir_timesteps - 1).
92
+ delay_compensation: Samples to crop from start of output audio to compensate
93
+ for group delay of the impulse response. If delay_compensation < 0 it
94
+ defaults to automatically calculating a constant group delay of the
95
+ windowed linear phase filter from frequency_impulse_response().
96
+ Returns:
97
+ Tensor of cropped and shifted audio.
98
+ Raises:
99
+ ValueError: If padding is not either 'valid' or 'same'.
100
+ """
101
+ # Crop the output.
102
+ if padding == 'valid':
103
+ crop_size = ir_size + audio_size - 1
104
+ elif padding == 'same':
105
+ crop_size = audio_size
106
+ else:
107
+ raise ValueError('Padding must be \'valid\' or \'same\', instead '
108
+ 'of {}.'.format(padding))
109
+
110
+ # Compensate for the group delay of the filter by trimming the front.
111
+ # For an impulse response produced by frequency_impulse_response(),
112
+ # the group delay is constant because the filter is linear phase.
113
+ total_size = int(audio.shape[-1])
114
+ crop = total_size - crop_size
115
+ start = (ir_size // 2 if delay_compensation < 0 else delay_compensation)
116
+ end = crop - start
117
+ return audio[:, start:-end]
118
+
119
+
120
+ def fft_convolve(audio,
121
+ impulse_response): # B, n_frames, 2*(n_mags-1)
122
+ """Filter audio with frames of time-varying impulse responses.
123
+ Time-varying filter. Given audio [batch, n_samples], and a series of impulse
124
+ responses [batch, n_frames, n_impulse_response], splits the audio into frames,
125
+ applies filters, and then overlap-and-adds audio back together.
126
+ Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute
127
+ convolution for large impulse response sizes.
128
+ Args:
129
+ audio: Input audio. Tensor of shape [batch, audio_timesteps].
130
+ impulse_response: Finite impulse response to convolve. Can either be a 2-D
131
+ Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch,
132
+ ir_frames, ir_size]. A 2-D tensor will apply a single linear
133
+ time-invariant filter to the audio. A 3-D Tensor will apply a linear
134
+ time-varying filter. Automatically chops the audio into equally shaped
135
+ blocks to match ir_frames.
136
+ Returns:
137
+ audio_out: Convolved audio. Tensor of shape
138
+ [batch, audio_timesteps].
139
+ """
140
+ # Add a frame dimension to impulse response if it doesn't have one.
141
+ ir_shape = impulse_response.size()
142
+ if len(ir_shape) == 2:
143
+ impulse_response = impulse_response.unsqueeze(1)
144
+ ir_shape = impulse_response.size()
145
+
146
+ # Get shapes of audio and impulse response.
147
+ batch_size_ir, n_ir_frames, ir_size = ir_shape
148
+ batch_size, audio_size = audio.size() # B, T
149
+
150
+ # Validate that batch sizes match.
151
+ if batch_size != batch_size_ir:
152
+ raise ValueError('Batch size of audio ({}) and impulse response ({}) must '
153
+ 'be the same.'.format(batch_size, batch_size_ir))
154
+
155
+ # Cut audio into 50% overlapped frames (center padding).
156
+ hop_size = int(audio_size / n_ir_frames)
157
+ frame_size = 2 * hop_size
158
+ audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size)
159
+
160
+ # Apply Bartlett (triangular) window
161
+ window = torch.bartlett_window(frame_size).to(audio_frames)
162
+ audio_frames = audio_frames * window
163
+
164
+ # Pad and FFT the audio and impulse responses.
165
+ fft_size = get_fft_size(frame_size, ir_size, power_of_2=False)
166
+ audio_fft = torch.fft.rfft(audio_frames, fft_size)
167
+ ir_fft = torch.fft.rfft(torch.cat((impulse_response,impulse_response[:,-1:,:]),1), fft_size)
168
+
169
+ # Multiply the FFTs (same as convolution in time).
170
+ audio_ir_fft = torch.multiply(audio_fft, ir_fft)
171
+
172
+ # Take the IFFT to resynthesize audio.
173
+ audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size)
174
+
175
+ # Overlap Add
176
+ batch_size, n_audio_frames, frame_size = audio_frames_out.size() # # B, n_frames+1, 2*(hop_size+n_mags-1)-1
177
+ fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size),kernel_size=(1, frame_size),stride=(1, hop_size))
178
+ output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1)
179
+
180
+ # Crop and shift the output audio.
181
+ output_signal = crop_and_compensate_delay(output_signal[:,hop_size:], audio_size, ir_size)
182
+ return output_signal
183
+
184
+
185
+ def apply_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1)
186
+ window_size: int = 0,
187
+ causal: bool = False):
188
+ """Apply a window to an impulse response and put in causal form.
189
+ Args:
190
+ impulse_response: A series of impulse responses frames to window, of shape
191
+ [batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ??????
192
+
193
+ window_size: Size of the window to apply in the time domain. If window_size
194
+ is less than 1, it defaults to the impulse_response size.
195
+ causal: Impulse response input is in causal form (peak in the middle).
196
+ Returns:
197
+ impulse_response: Windowed impulse response in causal form, with last
198
+ dimension cropped to window_size if window_size is greater than 0 and less
199
+ than ir_size.
200
+ """
201
+
202
+ # If IR is in causal form, put it in zero-phase form.
203
+ if causal:
204
+ impulse_response = torch.fftshift(impulse_response, axes=-1)
205
+
206
+ # Get a window for better time/frequency resolution than rectangular.
207
+ # Window defaults to IR size, cannot be bigger.
208
+ ir_size = int(impulse_response.size(-1))
209
+ if (window_size <= 0) or (window_size > ir_size):
210
+ window_size = ir_size
211
+ window = nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response)
212
+
213
+ # Zero pad the window and put in in zero-phase form.
214
+ padding = ir_size - window_size
215
+ if padding > 0:
216
+ half_idx = (window_size + 1) // 2
217
+ window = torch.cat([window[half_idx:],
218
+ torch.zeros([padding]),
219
+ window[:half_idx]], axis=0)
220
+ else:
221
+ window = window.roll(window.size(-1)//2, -1)
222
+
223
+ # Apply the window, to get new IR (both in zero-phase form).
224
+ window = window.unsqueeze(0)
225
+ impulse_response = impulse_response * window
226
+
227
+ # Put IR in causal form and trim zero padding.
228
+ if padding > 0:
229
+ first_half_start = (ir_size - (half_idx - 1)) + 1
230
+ second_half_end = half_idx + 1
231
+ impulse_response = torch.cat([impulse_response[..., first_half_start:],
232
+ impulse_response[..., :second_half_end]],
233
+ dim=-1)
234
+ else:
235
+ impulse_response = impulse_response.roll(impulse_response.size(-1)//2, -1)
236
+
237
+ return impulse_response
238
+
239
+
240
+ def apply_dynamic_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
241
+ half_width_frames): # B,n_frames, 1
242
+ ir_size = int(impulse_response.size(-1)) # 2*(n_mag -1) or 2*n_mag-1
243
+
244
+ window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
245
+ window[window > 1] = 0
246
+ window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag -1) or 2*n_mag-1
247
+
248
+ impulse_response = impulse_response.roll(ir_size // 2, -1)
249
+ impulse_response = impulse_response * window
250
+
251
+ return impulse_response
252
+
253
+
254
+ def frequency_impulse_response(magnitudes,
255
+ hann_window = True,
256
+ half_width_frames = None):
257
+
258
+ # Get the IR
259
+ impulse_response = torch.fft.irfft(magnitudes) # B, n_frames, 2*(n_mags-1)
260
+
261
+ # Window and put in causal form.
262
+ if hann_window:
263
+ if half_width_frames is None:
264
+ impulse_response = apply_window_to_impulse_response(impulse_response)
265
+ else:
266
+ impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
267
+ else:
268
+ impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
269
+
270
+ return impulse_response
271
+
272
+
273
+ def frequency_filter(audio,
274
+ magnitudes,
275
+ hann_window=True,
276
+ half_width_frames=None):
277
+
278
+ impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames)
279
+
280
+ return fft_convolve(audio, impulse_response)
281
+
DDSP-SVC/ddsp/loss.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchaudio
6
+ from torch.nn import functional as F
7
+ from .core import upsample
8
+
9
+ class SSSLoss(nn.Module):
10
+ """
11
+ Single-scale Spectral Loss.
12
+ """
13
+
14
+ def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
15
+ super().__init__()
16
+ self.n_fft = n_fft
17
+ self.alpha = alpha
18
+ self.eps = eps
19
+ self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length
20
+ self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=1, normalized=True, center=False)
21
+
22
+ def forward(self, x_true, x_pred):
23
+ S_true = self.spec(x_true) + self.eps
24
+ S_pred = self.spec(x_pred) + self.eps
25
+
26
+ converge_term = torch.mean(torch.linalg.norm(S_true - S_pred, dim = (1, 2)) / torch.linalg.norm(S_true + S_pred, dim = (1, 2)))
27
+
28
+ log_term = F.l1_loss(S_true.log(), S_pred.log())
29
+
30
+ loss = converge_term + self.alpha * log_term
31
+ return loss
32
+
33
+
34
+ class RSSLoss(nn.Module):
35
+ '''
36
+ Random-scale Spectral Loss.
37
+ '''
38
+
39
+ def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
40
+ super().__init__()
41
+ self.fft_min = fft_min
42
+ self.fft_max = fft_max
43
+ self.n_scale = n_scale
44
+ self.lossdict = {}
45
+ for n_fft in range(fft_min, fft_max):
46
+ self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device)
47
+
48
+ def forward(self, x_pred, x_true):
49
+ value = 0.
50
+ n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,))
51
+ for n_fft in n_ffts:
52
+ loss_func = self.lossdict[int(n_fft)]
53
+ value += loss_func(x_true, x_pred)
54
+ return value / self.n_scale
55
+
56
+
57
+
DDSP-SVC/ddsp/pcmer.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch import nn
4
+ import math
5
+ from functools import partial
6
+ from einops import rearrange, repeat
7
+
8
+ from local_attention import LocalAttention
9
+ import torch.nn.functional as F
10
+ #import fast_transformers.causal_product.causal_product_cuda
11
+
12
+ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
13
+ b, h, *_ = data.shape
14
+ # (batch size, head, length, model_dim)
15
+
16
+ # normalize model dim
17
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
18
+
19
+ # what is ration?, projection_matrix.shape[0] --> 266
20
+
21
+ ratio = (projection_matrix.shape[0] ** -0.5)
22
+
23
+ projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
24
+ projection = projection.type_as(data)
25
+
26
+ #data_dash = w^T x
27
+ data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
28
+
29
+
30
+ # diag_data = D**2
31
+ diag_data = data ** 2
32
+ diag_data = torch.sum(diag_data, dim=-1)
33
+ diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
34
+ diag_data = diag_data.unsqueeze(dim=-1)
35
+
36
+ #print ()
37
+ if is_query:
38
+ data_dash = ratio * (
39
+ torch.exp(data_dash - diag_data -
40
+ torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
41
+ else:
42
+ data_dash = ratio * (
43
+ torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps)
44
+
45
+ return data_dash.type_as(data)
46
+
47
+ def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
48
+ unstructured_block = torch.randn((cols, cols), device = device)
49
+ q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced')
50
+ q, r = map(lambda t: t.to(device), (q, r))
51
+
52
+ # proposed by @Parskatt
53
+ # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
54
+ if qr_uniform_q:
55
+ d = torch.diag(r, 0)
56
+ q *= d.sign()
57
+ return q.t()
58
+ def exists(val):
59
+ return val is not None
60
+
61
+ def empty(tensor):
62
+ return tensor.numel() == 0
63
+
64
+ def default(val, d):
65
+ return val if exists(val) else d
66
+
67
+ def cast_tuple(val):
68
+ return (val,) if not isinstance(val, tuple) else val
69
+
70
+ class PCmer(nn.Module):
71
+ """The encoder that is used in the Transformer model."""
72
+
73
+ def __init__(self,
74
+ num_layers,
75
+ num_heads,
76
+ dim_model,
77
+ dim_keys,
78
+ dim_values,
79
+ residual_dropout,
80
+ attention_dropout):
81
+ super().__init__()
82
+ self.num_layers = num_layers
83
+ self.num_heads = num_heads
84
+ self.dim_model = dim_model
85
+ self.dim_values = dim_values
86
+ self.dim_keys = dim_keys
87
+ self.residual_dropout = residual_dropout
88
+ self.attention_dropout = attention_dropout
89
+
90
+ self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
91
+
92
+ # METHODS ########################################################################################################
93
+
94
+ def forward(self, phone, mask=None):
95
+
96
+ # apply all layers to the input
97
+ for (i, layer) in enumerate(self._layers):
98
+ phone = layer(phone, mask)
99
+ # provide the final sequence
100
+ return phone
101
+
102
+
103
+ # ==================================================================================================================== #
104
+ # CLASS _ E N C O D E R L A Y E R #
105
+ # ==================================================================================================================== #
106
+
107
+
108
+ class _EncoderLayer(nn.Module):
109
+ """One layer of the encoder.
110
+
111
+ Attributes:
112
+ attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
113
+ feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
114
+ """
115
+
116
+ def __init__(self, parent: PCmer):
117
+ """Creates a new instance of ``_EncoderLayer``.
118
+
119
+ Args:
120
+ parent (Encoder): The encoder that the layers is created for.
121
+ """
122
+ super().__init__()
123
+
124
+
125
+ self.conformer = ConformerConvModule(parent.dim_model)
126
+ self.norm = nn.LayerNorm(parent.dim_model)
127
+ self.dropout = nn.Dropout(parent.residual_dropout)
128
+
129
+ # selfatt -> fastatt: performer!
130
+ self.attn = SelfAttention(dim = parent.dim_model,
131
+ heads = parent.num_heads,
132
+ causal = False)
133
+
134
+ # METHODS ########################################################################################################
135
+
136
+ def forward(self, phone, mask=None):
137
+
138
+ # compute attention sub-layer
139
+ phone = phone + (self.attn(self.norm(phone), mask=mask))
140
+
141
+ phone = phone + (self.conformer(phone))
142
+
143
+ return phone
144
+
145
+ def calc_same_padding(kernel_size):
146
+ pad = kernel_size // 2
147
+ return (pad, pad - (kernel_size + 1) % 2)
148
+
149
+ # helper classes
150
+
151
+ class Swish(nn.Module):
152
+ def forward(self, x):
153
+ return x * x.sigmoid()
154
+
155
+ class Transpose(nn.Module):
156
+ def __init__(self, dims):
157
+ super().__init__()
158
+ assert len(dims) == 2, 'dims must be a tuple of two dimensions'
159
+ self.dims = dims
160
+
161
+ def forward(self, x):
162
+ return x.transpose(*self.dims)
163
+
164
+ class GLU(nn.Module):
165
+ def __init__(self, dim):
166
+ super().__init__()
167
+ self.dim = dim
168
+
169
+ def forward(self, x):
170
+ out, gate = x.chunk(2, dim=self.dim)
171
+ return out * gate.sigmoid()
172
+
173
+ class DepthWiseConv1d(nn.Module):
174
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
175
+ super().__init__()
176
+ self.padding = padding
177
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
178
+
179
+ def forward(self, x):
180
+ x = F.pad(x, self.padding)
181
+ return self.conv(x)
182
+
183
+ class ConformerConvModule(nn.Module):
184
+ def __init__(
185
+ self,
186
+ dim,
187
+ causal = False,
188
+ expansion_factor = 2,
189
+ kernel_size = 31,
190
+ dropout = 0.):
191
+ super().__init__()
192
+
193
+ inner_dim = dim * expansion_factor
194
+ padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
195
+
196
+ self.net = nn.Sequential(
197
+ nn.LayerNorm(dim),
198
+ Transpose((1, 2)),
199
+ nn.Conv1d(dim, inner_dim * 2, 1),
200
+ GLU(dim=1),
201
+ DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
202
+ #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
203
+ Swish(),
204
+ nn.Conv1d(inner_dim, dim, 1),
205
+ Transpose((1, 2)),
206
+ nn.Dropout(dropout)
207
+ )
208
+
209
+ def forward(self, x):
210
+ return self.net(x)
211
+
212
+ def linear_attention(q, k, v):
213
+ if v is None:
214
+ #print (k.size(), q.size())
215
+ out = torch.einsum('...ed,...nd->...ne', k, q)
216
+ return out
217
+
218
+ else:
219
+ k_cumsum = k.sum(dim = -2)
220
+ #k_cumsum = k.sum(dim = -2)
221
+ D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8)
222
+
223
+ context = torch.einsum('...nd,...ne->...de', k, v)
224
+ #print ("TRUEEE: ", context.size(), q.size(), D_inv.size())
225
+ out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
226
+ return out
227
+
228
+ def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
229
+ nb_full_blocks = int(nb_rows / nb_columns)
230
+ #print (nb_full_blocks)
231
+ block_list = []
232
+
233
+ for _ in range(nb_full_blocks):
234
+ q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
235
+ block_list.append(q)
236
+ # block_list[n] is a orthogonal matrix ... (model_dim * model_dim)
237
+ #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1)))
238
+ #print (nb_rows, nb_full_blocks, nb_columns)
239
+ remaining_rows = nb_rows - nb_full_blocks * nb_columns
240
+ #print (remaining_rows)
241
+ if remaining_rows > 0:
242
+ q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
243
+ #print (q[:remaining_rows].size())
244
+ block_list.append(q[:remaining_rows])
245
+
246
+ final_matrix = torch.cat(block_list)
247
+
248
+ if scaling == 0:
249
+ multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
250
+ elif scaling == 1:
251
+ multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
252
+ else:
253
+ raise ValueError(f'Invalid scaling {scaling}')
254
+
255
+ return torch.diag(multiplier) @ final_matrix
256
+
257
+ class FastAttention(nn.Module):
258
+ def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False):
259
+ super().__init__()
260
+ nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
261
+
262
+ self.dim_heads = dim_heads
263
+ self.nb_features = nb_features
264
+ self.ortho_scaling = ortho_scaling
265
+
266
+ self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
267
+ projection_matrix = self.create_projection()
268
+ self.register_buffer('projection_matrix', projection_matrix)
269
+
270
+ self.generalized_attention = generalized_attention
271
+ self.kernel_fn = kernel_fn
272
+
273
+ # if this is turned on, no projection will be used
274
+ # queries and keys will be softmax-ed as in the original efficient attention paper
275
+ self.no_projection = no_projection
276
+
277
+ self.causal = causal
278
+ if causal:
279
+ try:
280
+ import fast_transformers.causal_product.causal_product_cuda
281
+ self.causal_linear_fn = partial(causal_linear_attention)
282
+ except ImportError:
283
+ print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
284
+ self.causal_linear_fn = causal_linear_attention_noncuda
285
+ @torch.no_grad()
286
+ def redraw_projection_matrix(self):
287
+ projections = self.create_projection()
288
+ self.projection_matrix.copy_(projections)
289
+ del projections
290
+
291
+ def forward(self, q, k, v):
292
+ device = q.device
293
+
294
+ if self.no_projection:
295
+ q = q.softmax(dim = -1)
296
+ k = torch.exp(k) if self.causal else k.softmax(dim = -2)
297
+
298
+ elif self.generalized_attention:
299
+ create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
300
+ q, k = map(create_kernel, (q, k))
301
+
302
+ else:
303
+ create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
304
+
305
+ q = create_kernel(q, is_query = True)
306
+ k = create_kernel(k, is_query = False)
307
+
308
+ attn_fn = linear_attention if not self.causal else self.causal_linear_fn
309
+ if v is None:
310
+ out = attn_fn(q, k, None)
311
+ return out
312
+ else:
313
+ out = attn_fn(q, k, v)
314
+ return out
315
+ class SelfAttention(nn.Module):
316
+ def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False):
317
+ super().__init__()
318
+ assert dim % heads == 0, 'dimension must be divisible by number of heads'
319
+ dim_head = default(dim_head, dim // heads)
320
+ inner_dim = dim_head * heads
321
+ self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
322
+
323
+ self.heads = heads
324
+ self.global_heads = heads - local_heads
325
+ self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
326
+
327
+ #print (heads, nb_features, dim_head)
328
+ #name_embedding = torch.zeros(110, heads, dim_head, dim_head)
329
+ #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True)
330
+
331
+
332
+ self.to_q = nn.Linear(dim, inner_dim)
333
+ self.to_k = nn.Linear(dim, inner_dim)
334
+ self.to_v = nn.Linear(dim, inner_dim)
335
+ self.to_out = nn.Linear(inner_dim, dim)
336
+ self.dropout = nn.Dropout(dropout)
337
+
338
+ @torch.no_grad()
339
+ def redraw_projection_matrix(self):
340
+ self.fast_attention.redraw_projection_matrix()
341
+ #torch.nn.init.zeros_(self.name_embedding)
342
+ #print (torch.sum(self.name_embedding))
343
+ def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
344
+ b, n, _, h, gh = *x.shape, self.heads, self.global_heads
345
+
346
+ cross_attend = exists(context)
347
+
348
+ context = default(context, x)
349
+ context_mask = default(context_mask, mask) if not cross_attend else context_mask
350
+ #print (torch.sum(self.name_embedding))
351
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
352
+
353
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
354
+ (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
355
+
356
+ attn_outs = []
357
+ #print (name)
358
+ #print (self.name_embedding[name].size())
359
+ if not empty(q):
360
+ if exists(context_mask):
361
+ global_mask = context_mask[:, None, :, None]
362
+ v.masked_fill_(~global_mask, 0.)
363
+ if cross_attend:
364
+ pass
365
+ #print (torch.sum(self.name_embedding))
366
+ #out = self.fast_attention(q,self.name_embedding[name],None)
367
+ #print (torch.sum(self.name_embedding[...,-1:]))
368
+ else:
369
+ out = self.fast_attention(q, k, v)
370
+ attn_outs.append(out)
371
+
372
+ if not empty(lq):
373
+ assert not cross_attend, 'local attention is not compatible with cross attention'
374
+ out = self.local_attn(lq, lk, lv, input_mask = mask)
375
+ attn_outs.append(out)
376
+
377
+ out = torch.cat(attn_outs, dim = 1)
378
+ out = rearrange(out, 'b h n d -> b n (h d)')
379
+ out = self.to_out(out)
380
+ return self.dropout(out)
DDSP-SVC/ddsp/unit2control.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gin
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn.utils import weight_norm
7
+
8
+ from .pcmer import PCmer
9
+
10
+
11
+ def split_to_dict(tensor, tensor_splits):
12
+ """Split a tensor into a dictionary of multiple tensors."""
13
+ labels = []
14
+ sizes = []
15
+
16
+ for k, v in tensor_splits.items():
17
+ labels.append(k)
18
+ sizes.append(v)
19
+
20
+ tensors = torch.split(tensor, sizes, dim=-1)
21
+ return dict(zip(labels, tensors))
22
+
23
+
24
+ class Unit2Control(nn.Module):
25
+ def __init__(
26
+ self,
27
+ input_channel,
28
+ n_spk,
29
+ output_splits):
30
+ super().__init__()
31
+ self.output_splits = output_splits
32
+ self.f0_embed = nn.Linear(1, 256)
33
+ self.phase_embed = nn.Linear(1, 256)
34
+ self.volume_embed = nn.Linear(1, 256)
35
+ self.n_spk = n_spk
36
+ if n_spk is not None and n_spk > 1:
37
+ self.spk_embed = nn.Embedding(n_spk, 256)
38
+
39
+ # conv in stack
40
+ self.stack = nn.Sequential(
41
+ nn.Conv1d(input_channel, 256, 3, 1, 1),
42
+ nn.GroupNorm(4, 256),
43
+ nn.LeakyReLU(),
44
+ nn.Conv1d(256, 256, 3, 1, 1))
45
+
46
+ # transformer
47
+ self.decoder = PCmer(
48
+ num_layers=3,
49
+ num_heads=8,
50
+ dim_model=256,
51
+ dim_keys=256,
52
+ dim_values=256,
53
+ residual_dropout=0.1,
54
+ attention_dropout=0.1)
55
+ self.norm = nn.LayerNorm(256)
56
+
57
+ # out
58
+ self.n_out = sum([v for k, v in output_splits.items()])
59
+ self.dense_out = weight_norm(
60
+ nn.Linear(256, self.n_out))
61
+
62
+ def forward(self, units, f0, phase, volume, spk_id = None, spk_mix_dict = None):
63
+
64
+ '''
65
+ input:
66
+ B x n_frames x n_unit
67
+ return:
68
+ dict of B x n_frames x feat
69
+ '''
70
+
71
+ x = self.stack(units.transpose(1,2)).transpose(1,2)
72
+ x = x + self.f0_embed((1+ f0 / 700).log()) + self.phase_embed(phase / np.pi) + self.volume_embed(volume)
73
+ if self.n_spk is not None and self.n_spk > 1:
74
+ if spk_mix_dict is not None:
75
+ for k, v in spk_mix_dict.items():
76
+ spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
77
+ x = x + v * self.spk_embed(spk_id_torch - 1)
78
+ else:
79
+ x = x + self.spk_embed(spk_id - 1)
80
+ x = self.decoder(x)
81
+ x = self.norm(x)
82
+ e = self.dense_out(x)
83
+ controls = split_to_dict(e, self.output_splits)
84
+
85
+ return controls
86
+
DDSP-SVC/ddsp/vocoder.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import yaml
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import pyworld as pw
7
+ import parselmouth
8
+ import torchcrepe
9
+ import resampy
10
+ from transformers import HubertModel, Wav2Vec2FeatureExtractor
11
+ from fairseq import checkpoint_utils
12
+ from encoder.hubert.model import HubertSoft
13
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
14
+ from torchaudio.transforms import Resample
15
+ from .unit2control import Unit2Control
16
+ from .core import frequency_filter, upsample, remove_above_fmax, MaskedAvgPool1d, MedianPool1d
17
+ import time
18
+
19
+ CREPE_RESAMPLE_KERNEL = {}
20
+
21
+ class F0_Extractor:
22
+ def __init__(self, f0_extractor, sample_rate = 44100, hop_size = 512, f0_min = 65, f0_max = 800):
23
+ self.f0_extractor = f0_extractor
24
+ self.sample_rate = sample_rate
25
+ self.hop_size = hop_size
26
+ self.f0_min = f0_min
27
+ self.f0_max = f0_max
28
+ if f0_extractor == 'crepe':
29
+ key_str = str(sample_rate)
30
+ if key_str not in CREPE_RESAMPLE_KERNEL:
31
+ CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width = 128)
32
+ self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str]
33
+
34
+ def extract(self, audio, uv_interp = False, device = None, silence_front = 0): # audio: 1d numpy array
35
+ # extractor start time
36
+ n_frames = int(len(audio) // self.hop_size) + 1
37
+
38
+ start_frame = int(silence_front * self.sample_rate / self.hop_size)
39
+ real_silence_front = start_frame * self.hop_size / self.sample_rate
40
+ audio = audio[int(np.round(real_silence_front * self.sample_rate)) : ]
41
+
42
+ # extract f0 using parselmouth
43
+ if self.f0_extractor == 'parselmouth':
44
+ f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(
45
+ time_step = self.hop_size / self.sample_rate,
46
+ voicing_threshold = 0.6,
47
+ pitch_floor = self.f0_min,
48
+ pitch_ceiling = self.f0_max).selected_array['frequency']
49
+ pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2
50
+ f0 = np.pad(f0,(pad_size, n_frames - len(f0) - pad_size))
51
+
52
+ # extract f0 using dio
53
+ elif self.f0_extractor == 'dio':
54
+ _f0, t = pw.dio(
55
+ audio.astype('double'),
56
+ self.sample_rate,
57
+ f0_floor = self.f0_min,
58
+ f0_ceil = self.f0_max,
59
+ channels_in_octave=2,
60
+ frame_period = (1000 * self.hop_size / self.sample_rate))
61
+ f0 = pw.stonemask(audio.astype('double'), _f0, t, self.sample_rate)
62
+ f0 = np.pad(f0.astype('float'), (start_frame, n_frames - len(f0) - start_frame))
63
+
64
+ # extract f0 using harvest
65
+ elif self.f0_extractor == 'harvest':
66
+ f0, _ = pw.harvest(
67
+ audio.astype('double'),
68
+ self.sample_rate,
69
+ f0_floor = self.f0_min,
70
+ f0_ceil = self.f0_max,
71
+ frame_period = (1000 * self.hop_size / self.sample_rate))
72
+ f0 = np.pad(f0.astype('float'), (start_frame, n_frames - len(f0) - start_frame))
73
+
74
+ # extract f0 using crepe
75
+ elif self.f0_extractor == 'crepe':
76
+ if device is None:
77
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
78
+ resample_kernel = self.resample_kernel.to(device)
79
+ wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device))
80
+
81
+ f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model='full', batch_size=512, device=device, return_periodicity=True)
82
+ pd = MedianPool1d(pd, 4)
83
+ f0 = torchcrepe.threshold.At(0.05)(f0, pd)
84
+ f0 = MaskedAvgPool1d(f0, 4)
85
+
86
+ f0 = f0.squeeze(0).cpu().numpy()
87
+ f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)])
88
+ f0 = np.pad(f0, (start_frame, 0))
89
+
90
+ else:
91
+ raise ValueError(f" [x] Unknown f0 extractor: {f0_extractor}")
92
+
93
+ # interpolate the unvoiced f0
94
+ if uv_interp:
95
+ uv = f0 == 0
96
+ if len(f0[~uv]) > 0:
97
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
98
+ f0[f0 < self.f0_min] = self.f0_min
99
+ return f0
100
+
101
+
102
+ class Volume_Extractor:
103
+ def __init__(self, hop_size = 512):
104
+ self.hop_size = hop_size
105
+
106
+ def extract(self, audio): # audio: 1d numpy array
107
+ n_frames = int(len(audio) // self.hop_size) + 1
108
+ audio2 = audio ** 2
109
+ audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
110
+ volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
111
+ volume = np.sqrt(volume)
112
+ return volume
113
+
114
+
115
+ class Units_Encoder:
116
+ def __init__(self, encoder, encoder_ckpt, encoder_sample_rate = 16000, encoder_hop_size = 320, device = None,
117
+ cnhubertsoft_gate=10):
118
+ if device is None:
119
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
120
+ self.device = device
121
+
122
+ is_loaded_encoder = False
123
+ if encoder == 'hubertsoft':
124
+ self.model = Audio2HubertSoft(encoder_ckpt).to(device)
125
+ is_loaded_encoder = True
126
+ if encoder == 'hubertbase':
127
+ self.model = Audio2HubertBase(encoder_ckpt, device=device)
128
+ is_loaded_encoder = True
129
+ if encoder == 'hubertbase768':
130
+ self.model = Audio2HubertBase768(encoder_ckpt, device=device)
131
+ is_loaded_encoder = True
132
+ if encoder == 'contentvec':
133
+ self.model = Audio2ContentVec(encoder_ckpt, device=device)
134
+ is_loaded_encoder = True
135
+ if encoder == 'contentvec768':
136
+ self.model = Audio2ContentVec768(encoder_ckpt, device=device)
137
+ is_loaded_encoder = True
138
+ if encoder == 'contentvec768l12':
139
+ self.model = Audio2ContentVec768L12(encoder_ckpt, device=device)
140
+ is_loaded_encoder = True
141
+ if encoder == 'cnhubertsoftfish':
142
+ self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate)
143
+ is_loaded_encoder = True
144
+ if not is_loaded_encoder:
145
+ raise ValueError(f" [x] Unknown units encoder: {encoder}")
146
+
147
+ self.resample_kernel = {}
148
+ self.encoder_sample_rate = encoder_sample_rate
149
+ self.encoder_hop_size = encoder_hop_size
150
+
151
+ def encode(self,
152
+ audio, # B, T
153
+ sample_rate,
154
+ hop_size):
155
+
156
+ # resample
157
+ if sample_rate == self.encoder_sample_rate:
158
+ audio_res = audio
159
+ else:
160
+ key_str = str(sample_rate)
161
+ if key_str not in self.resample_kernel:
162
+ self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width = 128).to(self.device)
163
+ audio_res = self.resample_kernel[key_str](audio)
164
+
165
+ # encode
166
+ if audio_res.size(-1) < self.encoder_hop_size:
167
+ audio_res = torch.nn.functional.pad(audio, (0, self.encoder_hop_size - audio_res.size(-1)))
168
+ units = self.model(audio_res)
169
+
170
+ # alignment
171
+ n_frames = audio.size(-1) // hop_size + 1
172
+ ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
173
+ index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max = units.size(1) - 1)
174
+ units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
175
+ return units_aligned
176
+
177
+ class Audio2HubertSoft(torch.nn.Module):
178
+ def __init__(self, path, h_sample_rate = 16000, h_hop_size = 320):
179
+ super().__init__()
180
+ print(' [Encoder Model] HuBERT Soft')
181
+ self.hubert = HubertSoft()
182
+ print(' [Loading] ' + path)
183
+ checkpoint = torch.load(path)
184
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
185
+ self.hubert.load_state_dict(checkpoint)
186
+ self.hubert.eval()
187
+
188
+ def forward(self,
189
+ audio): # B, T
190
+ with torch.inference_mode():
191
+ units = self.hubert.units(audio.unsqueeze(1))
192
+ return units
193
+
194
+
195
+ class Audio2ContentVec():
196
+ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
197
+ self.device = device
198
+ print(' [Encoder Model] Content Vec')
199
+ print(' [Loading] ' + path)
200
+ self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
201
+ self.hubert = self.models[0]
202
+ self.hubert = self.hubert.to(self.device)
203
+ self.hubert.eval()
204
+
205
+ def __call__(self,
206
+ audio): # B, T
207
+ # wav_tensor = torch.from_numpy(audio).to(self.device)
208
+ wav_tensor = audio
209
+ feats = wav_tensor.view(1, -1)
210
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
211
+ inputs = {
212
+ "source": feats.to(wav_tensor.device),
213
+ "padding_mask": padding_mask.to(wav_tensor.device),
214
+ "output_layer": 9, # layer 9
215
+ }
216
+ with torch.no_grad():
217
+ logits = self.hubert.extract_features(**inputs)
218
+ feats = self.hubert.final_proj(logits[0])
219
+ units = feats # .transpose(2, 1)
220
+ return units
221
+
222
+
223
+ class Audio2ContentVec768():
224
+ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
225
+ self.device = device
226
+ print(' [Encoder Model] Content Vec')
227
+ print(' [Loading] ' + path)
228
+ self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
229
+ self.hubert = self.models[0]
230
+ self.hubert = self.hubert.to(self.device)
231
+ self.hubert.eval()
232
+
233
+ def __call__(self,
234
+ audio): # B, T
235
+ # wav_tensor = torch.from_numpy(audio).to(self.device)
236
+ wav_tensor = audio
237
+ feats = wav_tensor.view(1, -1)
238
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
239
+ inputs = {
240
+ "source": feats.to(wav_tensor.device),
241
+ "padding_mask": padding_mask.to(wav_tensor.device),
242
+ "output_layer": 9, # layer 9
243
+ }
244
+ with torch.no_grad():
245
+ logits = self.hubert.extract_features(**inputs)
246
+ feats = logits[0]
247
+ units = feats # .transpose(2, 1)
248
+ return units
249
+
250
+
251
+ class Audio2ContentVec768L12():
252
+ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
253
+ self.device = device
254
+ print(' [Encoder Model] Content Vec')
255
+ print(' [Loading] ' + path)
256
+ self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
257
+ self.hubert = self.models[0]
258
+ self.hubert = self.hubert.to(self.device)
259
+ self.hubert.eval()
260
+
261
+ def __call__(self,
262
+ audio): # B, T
263
+ # wav_tensor = torch.from_numpy(audio).to(self.device)
264
+ wav_tensor = audio
265
+ feats = wav_tensor.view(1, -1)
266
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
267
+ inputs = {
268
+ "source": feats.to(wav_tensor.device),
269
+ "padding_mask": padding_mask.to(wav_tensor.device),
270
+ "output_layer": 12, # layer 12
271
+ }
272
+ with torch.no_grad():
273
+ logits = self.hubert.extract_features(**inputs)
274
+ feats = logits[0]
275
+ units = feats # .transpose(2, 1)
276
+ return units
277
+
278
+
279
+ class CNHubertSoftFish(torch.nn.Module):
280
+ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu', gate_size=10):
281
+ super().__init__()
282
+ self.device = device
283
+ self.gate_size = gate_size
284
+
285
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
286
+ "./pretrain/TencentGameMate/chinese-hubert-base")
287
+ self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
288
+ self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256))
289
+ # self.label_embedding = nn.Embedding(128, 256)
290
+
291
+ state_dict = torch.load(path, map_location=device)
292
+ self.load_state_dict(state_dict)
293
+
294
+ @torch.no_grad()
295
+ def forward(self, audio):
296
+ input_values = self.feature_extractor(
297
+ audio, sampling_rate=16000, return_tensors="pt"
298
+ ).input_values
299
+ input_values = input_values.to(self.model.device)
300
+
301
+ return self._forward(input_values[0])
302
+
303
+ @torch.no_grad()
304
+ def _forward(self, input_values):
305
+ features = self.model(input_values)
306
+ features = self.proj(features.last_hidden_state)
307
+
308
+ # Top-k gating
309
+ topk, indices = torch.topk(features, self.gate_size, dim=2)
310
+ features = torch.zeros_like(features).scatter(2, indices, topk)
311
+ features = features / features.sum(2, keepdim=True)
312
+
313
+ return features.to(self.device) # .transpose(1, 2)
314
+
315
+
316
+ class Audio2HubertBase():
317
+ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
318
+ self.device = device
319
+ print(' [Encoder Model] HuBERT Base')
320
+ print(' [Loading] ' + path)
321
+ self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
322
+ self.hubert = self.models[0]
323
+ self.hubert = self.hubert.to(self.device)
324
+ self.hubert = self.hubert.float()
325
+ self.hubert.eval()
326
+
327
+ def __call__(self,
328
+ audio): # B, T
329
+ with torch.no_grad():
330
+ padding_mask = torch.BoolTensor(audio.shape).fill_(False)
331
+ inputs = {
332
+ "source": audio.to(self.device),
333
+ "padding_mask": padding_mask.to(self.device),
334
+ "output_layer": 9, # layer 9
335
+ }
336
+ logits = self.hubert.extract_features(**inputs)
337
+ units = self.hubert.final_proj(logits[0])
338
+ return units
339
+
340
+
341
+ class Audio2HubertBase768():
342
+ def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
343
+ self.device = device
344
+ print(' [Encoder Model] HuBERT Base')
345
+ print(' [Loading] ' + path)
346
+ self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
347
+ self.hubert = self.models[0]
348
+ self.hubert = self.hubert.to(self.device)
349
+ self.hubert = self.hubert.float()
350
+ self.hubert.eval()
351
+
352
+ def __call__(self,
353
+ audio): # B, T
354
+ with torch.no_grad():
355
+ padding_mask = torch.BoolTensor(audio.shape).fill_(False)
356
+ inputs = {
357
+ "source": audio.to(self.device),
358
+ "padding_mask": padding_mask.to(self.device),
359
+ "output_layer": 9, # layer 9
360
+ }
361
+ logits = self.hubert.extract_features(**inputs)
362
+ units = logits[0]
363
+ return units
364
+
365
+
366
+ class DotDict(dict):
367
+ def __getattr__(*args):
368
+ val = dict.get(*args)
369
+ return DotDict(val) if type(val) is dict else val
370
+
371
+ __setattr__ = dict.__setitem__
372
+ __delattr__ = dict.__delitem__
373
+
374
+ def load_model(
375
+ model_path,
376
+ device='cpu'):
377
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
378
+ with open(config_file, "r") as config:
379
+ args = yaml.safe_load(config)
380
+ args = DotDict(args)
381
+
382
+ # load model
383
+ model = None
384
+
385
+ if args.model.type == 'Sins':
386
+ model = Sins(
387
+ sampling_rate=args.data.sampling_rate,
388
+ block_size=args.data.block_size,
389
+ n_harmonics=args.model.n_harmonics,
390
+ n_mag_allpass=args.model.n_mag_allpass,
391
+ n_mag_noise=args.model.n_mag_noise,
392
+ n_unit=args.data.encoder_out_channels,
393
+ n_spk=args.model.n_spk)
394
+
395
+ elif args.model.type == 'CombSub':
396
+ model = CombSub(
397
+ sampling_rate=args.data.sampling_rate,
398
+ block_size=args.data.block_size,
399
+ n_mag_allpass=args.model.n_mag_allpass,
400
+ n_mag_harmonic=args.model.n_mag_harmonic,
401
+ n_mag_noise=args.model.n_mag_noise,
402
+ n_unit=args.data.encoder_out_channels,
403
+ n_spk=args.model.n_spk)
404
+
405
+ elif args.model.type == 'CombSubFast':
406
+ model = CombSubFast(
407
+ sampling_rate=args.data.sampling_rate,
408
+ block_size=args.data.block_size,
409
+ n_unit=args.data.encoder_out_channels,
410
+ n_spk=args.model.n_spk)
411
+
412
+ else:
413
+ raise ValueError(f" [x] Unknown Model: {args.model.type}")
414
+
415
+ print(' [Loading] ' + model_path)
416
+ ckpt = torch.load(model_path, map_location=torch.device(device))
417
+ model.to(device)
418
+ model.load_state_dict(ckpt['model'])
419
+ model.eval()
420
+ return model, args
421
+
422
+
423
+ class Sins(torch.nn.Module):
424
+ def __init__(self,
425
+ sampling_rate,
426
+ block_size,
427
+ n_harmonics,
428
+ n_mag_allpass,
429
+ n_mag_noise,
430
+ n_unit=256,
431
+ n_spk=1):
432
+ super().__init__()
433
+
434
+ print(' [DDSP Model] Sinusoids Additive Synthesiser')
435
+
436
+ # params
437
+ self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
438
+ self.register_buffer("block_size", torch.tensor(block_size))
439
+ # Unit2Control
440
+ split_map = {
441
+ 'amplitudes': n_harmonics,
442
+ 'group_delay': n_mag_allpass,
443
+ 'noise_magnitude': n_mag_noise,
444
+ }
445
+ self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
446
+
447
+ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, max_upsample_dim=32):
448
+ '''
449
+ units_frames: B x n_frames x n_unit
450
+ f0_frames: B x n_frames x 1
451
+ volume_frames: B x n_frames x 1
452
+ spk_id: B x 1
453
+ '''
454
+ # exciter phase
455
+ f0 = upsample(f0_frames, self.block_size)
456
+ if infer:
457
+ x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
458
+ else:
459
+ x = torch.cumsum(f0 / self.sampling_rate, axis=1)
460
+ if initial_phase is not None:
461
+ x += initial_phase.to(x) / 2 / np.pi
462
+ x = x - torch.round(x)
463
+ x = x.to(f0)
464
+
465
+ phase = 2 * np.pi * x
466
+ phase_frames = phase[:, ::self.block_size, :]
467
+
468
+ # parameter prediction
469
+ ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
470
+
471
+ amplitudes_frames = torch.exp(ctrls['amplitudes'])/ 128
472
+ group_delay = np.pi * torch.tanh(ctrls['group_delay'])
473
+ noise_param = torch.exp(ctrls['noise_magnitude']) / 128
474
+
475
+ # sinusoids exciter signal
476
+ amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start = 1)
477
+ n_harmonic = amplitudes_frames.shape[-1]
478
+ level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
479
+ sinusoids = 0.
480
+ for n in range(( n_harmonic - 1) // max_upsample_dim + 1):
481
+ start = n * max_upsample_dim
482
+ end = (n + 1) * max_upsample_dim
483
+ phases = phase * level_harmonic[start:end]
484
+ amplitudes = upsample(amplitudes_frames[:,:,start:end], self.block_size)
485
+ sinusoids += (torch.sin(phases) * amplitudes).sum(-1)
486
+
487
+ # harmonic part filter (apply group-delay)
488
+ harmonic = frequency_filter(
489
+ sinusoids,
490
+ torch.exp(1.j * torch.cumsum(group_delay, axis = -1)),
491
+ hann_window = False)
492
+
493
+ # noise part filter
494
+ noise = torch.rand_like(harmonic) * 2 - 1
495
+ noise = frequency_filter(
496
+ noise,
497
+ torch.complex(noise_param, torch.zeros_like(noise_param)),
498
+ hann_window = True)
499
+
500
+ signal = harmonic + noise
501
+
502
+ return signal, phase, (harmonic, noise) #, (noise_param, noise_param)
503
+
504
+ class CombSubFast(torch.nn.Module):
505
+ def __init__(self,
506
+ sampling_rate,
507
+ block_size,
508
+ n_unit=256,
509
+ n_spk=1):
510
+ super().__init__()
511
+
512
+ print(' [DDSP Model] Combtooth Subtractive Synthesiser')
513
+ # params
514
+ self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
515
+ self.register_buffer("block_size", torch.tensor(block_size))
516
+ self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
517
+ #Unit2Control
518
+ split_map = {
519
+ 'harmonic_magnitude': block_size + 1,
520
+ 'harmonic_phase': block_size + 1,
521
+ 'noise_magnitude': block_size + 1
522
+ }
523
+ self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
524
+
525
+ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
526
+ '''
527
+ units_frames: B x n_frames x n_unit
528
+ f0_frames: B x n_frames x 1
529
+ volume_frames: B x n_frames x 1
530
+ spk_id: B x 1
531
+ '''
532
+ # exciter phase
533
+ f0 = upsample(f0_frames, self.block_size)
534
+ if infer:
535
+ x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
536
+ else:
537
+ x = torch.cumsum(f0 / self.sampling_rate, axis=1)
538
+ if initial_phase is not None:
539
+ x += initial_phase.to(x) / 2 / np.pi
540
+ x = x - torch.round(x)
541
+ x = x.to(f0)
542
+
543
+ phase_frames = 2 * np.pi * x[:, ::self.block_size, :]
544
+
545
+ # parameter prediction
546
+ ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
547
+
548
+ src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase'])
549
+ src_filter = torch.cat((src_filter, src_filter[:,-1:,:]), 1)
550
+ noise_filter= torch.exp(ctrls['noise_magnitude']) / 128
551
+ noise_filter = torch.cat((noise_filter, noise_filter[:,-1:,:]), 1)
552
+
553
+ # combtooth exciter signal
554
+ combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
555
+ combtooth = combtooth.squeeze(-1)
556
+ combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
557
+ combtooth_frames = combtooth_frames * self.window
558
+ combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)
559
+
560
+ # noise exciter signal
561
+ noise = torch.rand_like(combtooth) * 2 - 1
562
+ noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
563
+ noise_frames = noise_frames * self.window
564
+ noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)
565
+
566
+ # apply the filters
567
+ signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter
568
+
569
+ # take the ifft to resynthesize audio.
570
+ signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window
571
+
572
+ # overlap add
573
+ fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
574
+ signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size]
575
+
576
+ return signal, phase_frames, (signal, signal)
577
+
578
+ class CombSub(torch.nn.Module):
579
+ def __init__(self,
580
+ sampling_rate,
581
+ block_size,
582
+ n_mag_allpass,
583
+ n_mag_harmonic,
584
+ n_mag_noise,
585
+ n_unit=256,
586
+ n_spk=1):
587
+ super().__init__()
588
+
589
+ print(' [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)')
590
+ # params
591
+ self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
592
+ self.register_buffer("block_size", torch.tensor(block_size))
593
+ #Unit2Control
594
+ split_map = {
595
+ 'group_delay': n_mag_allpass,
596
+ 'harmonic_magnitude': n_mag_harmonic,
597
+ 'noise_magnitude': n_mag_noise
598
+ }
599
+ self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
600
+
601
+ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
602
+ '''
603
+ units_frames: B x n_frames x n_unit
604
+ f0_frames: B x n_frames x 1
605
+ volume_frames: B x n_frames x 1
606
+ spk_id: B x 1
607
+ '''
608
+ # exciter phase
609
+ f0 = upsample(f0_frames, self.block_size)
610
+ if infer:
611
+ x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
612
+ else:
613
+ x = torch.cumsum(f0 / self.sampling_rate, axis=1)
614
+ if initial_phase is not None:
615
+ x += initial_phase.to(x) / 2 / np.pi
616
+ x = x - torch.round(x)
617
+ x = x.to(f0)
618
+
619
+ phase_frames = 2 * np.pi * x[:, ::self.block_size, :]
620
+
621
+ # parameter prediction
622
+ ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
623
+
624
+ group_delay = np.pi * torch.tanh(ctrls['group_delay'])
625
+ src_param = torch.exp(ctrls['harmonic_magnitude'])
626
+ noise_param = torch.exp(ctrls['noise_magnitude']) / 128
627
+
628
+ # combtooth exciter signal
629
+ combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
630
+ combtooth = combtooth.squeeze(-1)
631
+
632
+ # harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction)
633
+ harmonic = frequency_filter(
634
+ combtooth,
635
+ torch.exp(1.j * torch.cumsum(group_delay, axis = -1)),
636
+ hann_window = False)
637
+ harmonic = frequency_filter(
638
+ harmonic,
639
+ torch.complex(src_param, torch.zeros_like(src_param)),
640
+ hann_window = True,
641
+ half_width_frames = 1.5 * self.sampling_rate / (f0_frames + 1e-3))
642
+
643
+ # noise part filter (using constant-windowed LTV-FIR, without group-delay)
644
+ noise = torch.rand_like(harmonic) * 2 - 1
645
+ noise = frequency_filter(
646
+ noise,
647
+ torch.complex(noise_param, torch.zeros_like(noise_param)),
648
+ hann_window = True)
649
+
650
+ signal = harmonic + noise
651
+
652
+ return signal, phase_frames, (harmonic, noise)
DDSP-SVC/diffusion/data_loaders.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import librosa
5
+ import torch
6
+ import random
7
+ from tqdm import tqdm
8
+ from torch.utils.data import Dataset
9
+
10
+ def traverse_dir(
11
+ root_dir,
12
+ extension,
13
+ amount=None,
14
+ str_include=None,
15
+ str_exclude=None,
16
+ is_pure=False,
17
+ is_sort=False,
18
+ is_ext=True):
19
+
20
+ file_list = []
21
+ cnt = 0
22
+ for root, _, files in os.walk(root_dir):
23
+ for file in files:
24
+ if file.endswith(extension):
25
+ # path
26
+ mix_path = os.path.join(root, file)
27
+ pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
28
+
29
+ # amount
30
+ if (amount is not None) and (cnt == amount):
31
+ if is_sort:
32
+ file_list.sort()
33
+ return file_list
34
+
35
+ # check string
36
+ if (str_include is not None) and (str_include not in pure_path):
37
+ continue
38
+ if (str_exclude is not None) and (str_exclude in pure_path):
39
+ continue
40
+
41
+ if not is_ext:
42
+ ext = pure_path.split('.')[-1]
43
+ pure_path = pure_path[:-(len(ext)+1)]
44
+ file_list.append(pure_path)
45
+ cnt += 1
46
+ if is_sort:
47
+ file_list.sort()
48
+ return file_list
49
+
50
+
51
+ def get_data_loaders(args, whole_audio=False):
52
+ data_train = AudioDataset(
53
+ args.data.train_path,
54
+ waveform_sec=args.data.duration,
55
+ hop_size=args.data.block_size,
56
+ sample_rate=args.data.sampling_rate,
57
+ load_all_data=args.train.cache_all_data,
58
+ whole_audio=whole_audio,
59
+ n_spk=args.model.n_spk,
60
+ device=args.train.cache_device,
61
+ fp16=args.train.cache_fp16,
62
+ use_aug=True)
63
+ loader_train = torch.utils.data.DataLoader(
64
+ data_train ,
65
+ batch_size=args.train.batch_size if not whole_audio else 1,
66
+ shuffle=True,
67
+ num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0,
68
+ persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False,
69
+ pin_memory=True if args.train.cache_device=='cpu' else False
70
+ )
71
+ data_valid = AudioDataset(
72
+ args.data.valid_path,
73
+ waveform_sec=args.data.duration,
74
+ hop_size=args.data.block_size,
75
+ sample_rate=args.data.sampling_rate,
76
+ load_all_data=args.train.cache_all_data,
77
+ whole_audio=True,
78
+ n_spk=args.model.n_spk)
79
+ loader_valid = torch.utils.data.DataLoader(
80
+ data_valid,
81
+ batch_size=1,
82
+ shuffle=False,
83
+ num_workers=0,
84
+ pin_memory=True
85
+ )
86
+ return loader_train, loader_valid
87
+
88
+
89
+ class AudioDataset(Dataset):
90
+ def __init__(
91
+ self,
92
+ path_root,
93
+ waveform_sec,
94
+ hop_size,
95
+ sample_rate,
96
+ load_all_data=True,
97
+ whole_audio=False,
98
+ n_spk=1,
99
+ device='cpu',
100
+ fp16=False,
101
+ use_aug=False,
102
+ ):
103
+ super().__init__()
104
+
105
+ self.waveform_sec = waveform_sec
106
+ self.sample_rate = sample_rate
107
+ self.hop_size = hop_size
108
+ self.path_root = path_root
109
+ self.paths = traverse_dir(
110
+ os.path.join(path_root, 'audio'),
111
+ extension='wav',
112
+ is_pure=True,
113
+ is_sort=True,
114
+ is_ext=False
115
+ )
116
+ self.whole_audio = whole_audio
117
+ self.use_aug = use_aug
118
+ self.data_buffer={}
119
+ self.pitch_aug_dict = np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item()
120
+ if load_all_data:
121
+ print('Load all the data from :', path_root)
122
+ else:
123
+ print('Load the f0, volume data from :', path_root)
124
+ for name in tqdm(self.paths, total=len(self.paths)):
125
+ path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
126
+ duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)
127
+
128
+ path_f0 = os.path.join(self.path_root, 'f0', name) + '.npy'
129
+ f0 = np.load(path_f0)
130
+ f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device)
131
+
132
+ path_volume = os.path.join(self.path_root, 'volume', name) + '.npy'
133
+ volume = np.load(path_volume)
134
+ volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)
135
+
136
+ path_augvol = os.path.join(self.path_root, 'aug_vol', name) + '.npy'
137
+ aug_vol = np.load(path_augvol)
138
+ aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)
139
+
140
+ if n_spk is not None and n_spk > 1:
141
+ spk_id = int(os.path.dirname(name)) if str.isdigit(os.path.dirname(name)) else 0
142
+ if spk_id < 1 or spk_id > n_spk:
143
+ raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ')
144
+ else:
145
+ spk_id = 1
146
+ spk_id = torch.LongTensor(np.array([spk_id])).to(device)
147
+
148
+ if load_all_data:
149
+ '''
150
+ audio, sr = librosa.load(path_audio, sr=self.sample_rate)
151
+ if len(audio.shape) > 1:
152
+ audio = librosa.to_mono(audio)
153
+ audio = torch.from_numpy(audio).to(device)
154
+ '''
155
+ path_mel = os.path.join(self.path_root, 'mel', name) + '.npy'
156
+ mel = np.load(path_mel)
157
+ mel = torch.from_numpy(mel).to(device)
158
+
159
+ path_augmel = os.path.join(self.path_root, 'aug_mel', name) + '.npy'
160
+ aug_mel = np.load(path_augmel)
161
+ aug_mel = torch.from_numpy(aug_mel).to(device)
162
+
163
+ path_units = os.path.join(self.path_root, 'units', name) + '.npy'
164
+ units = np.load(path_units)
165
+ units = torch.from_numpy(units).to(device)
166
+
167
+ if fp16:
168
+ mel = mel.half()
169
+ aug_mel = aug_mel.half()
170
+ units = units.half()
171
+
172
+ self.data_buffer[name] = {
173
+ 'duration': duration,
174
+ 'mel': mel,
175
+ 'aug_mel': aug_mel,
176
+ 'units': units,
177
+ 'f0': f0,
178
+ 'volume': volume,
179
+ 'aug_vol': aug_vol,
180
+ 'spk_id': spk_id
181
+ }
182
+ else:
183
+ self.data_buffer[name] = {
184
+ 'duration': duration,
185
+ 'f0': f0,
186
+ 'volume': volume,
187
+ 'aug_vol': aug_vol,
188
+ 'spk_id': spk_id
189
+ }
190
+
191
+
192
+ def __getitem__(self, file_idx):
193
+ name = self.paths[file_idx]
194
+ data_buffer = self.data_buffer[name]
195
+ # check duration. if too short, then skip
196
+ if data_buffer['duration'] < (self.waveform_sec + 0.1):
197
+ return self.__getitem__( (file_idx + 1) % len(self.paths))
198
+
199
+ # get item
200
+ return self.get_data(name, data_buffer)
201
+
202
+ def get_data(self, name, data_buffer):
203
+ frame_resolution = self.hop_size / self.sample_rate
204
+ duration = data_buffer['duration']
205
+ waveform_sec = duration if self.whole_audio else self.waveform_sec
206
+
207
+ # load audio
208
+ idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
209
+ start_frame = int(idx_from / frame_resolution)
210
+ units_frame_len = int(waveform_sec / frame_resolution)
211
+ aug_flag = random.choice([True, False]) and self.use_aug
212
+ '''
213
+ audio = data_buffer.get('audio')
214
+ if audio is None:
215
+ path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
216
+ audio, sr = librosa.load(
217
+ path_audio,
218
+ sr = self.sample_rate,
219
+ offset = start_frame * frame_resolution,
220
+ duration = waveform_sec)
221
+ if len(audio.shape) > 1:
222
+ audio = librosa.to_mono(audio)
223
+ # clip audio into N seconds
224
+ audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
225
+ audio = torch.from_numpy(audio).float()
226
+ else:
227
+ audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
228
+ '''
229
+ # load mel
230
+ mel_key = 'aug_mel' if aug_flag else 'mel'
231
+ mel = data_buffer.get(mel_key)
232
+ if mel is None:
233
+ mel = os.path.join(self.path_root, mel_key, name) + '.npy'
234
+ mel = np.load(mel)
235
+ mel = mel[start_frame : start_frame + units_frame_len]
236
+ mel = torch.from_numpy(mel).float()
237
+ else:
238
+ mel = mel[start_frame : start_frame + units_frame_len]
239
+
240
+ # load units
241
+ units = data_buffer.get('units')
242
+ if units is None:
243
+ units = os.path.join(self.path_root, 'units', name) + '.npy'
244
+ units = np.load(units)
245
+ units = units[start_frame : start_frame + units_frame_len]
246
+ units = torch.from_numpy(units).float()
247
+ else:
248
+ units = units[start_frame : start_frame + units_frame_len]
249
+
250
+ # load f0
251
+ f0 = data_buffer.get('f0')
252
+ aug_shift = 0
253
+ if aug_flag:
254
+ aug_shift = self.pitch_aug_dict[name]
255
+ f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len]
256
+
257
+ # load volume
258
+ vol_key = 'aug_vol' if aug_flag else 'volume'
259
+ volume = data_buffer.get(vol_key)
260
+ volume_frames = volume[start_frame : start_frame + units_frame_len]
261
+
262
+ # load spk_id
263
+ spk_id = data_buffer.get('spk_id')
264
+
265
+ # load shift
266
+ aug_shift = torch.LongTensor(np.array([[aug_shift]]))
267
+
268
+ return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name)
269
+
270
+ def __len__(self):
271
+ return len(self.paths)
DDSP-SVC/diffusion/diffusion.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from functools import partial
3
+ from inspect import isfunction
4
+ import torch.nn.functional as F
5
+ import librosa.sequence
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from tqdm import tqdm
10
+
11
+
12
+ def exists(x):
13
+ return x is not None
14
+
15
+
16
+ def default(val, d):
17
+ if exists(val):
18
+ return val
19
+ return d() if isfunction(d) else d
20
+
21
+
22
+ def extract(a, t, x_shape):
23
+ b, *_ = t.shape
24
+ out = a.gather(-1, t)
25
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
26
+
27
+
28
+ def noise_like(shape, device, repeat=False):
29
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
30
+ noise = lambda: torch.randn(shape, device=device)
31
+ return repeat_noise() if repeat else noise()
32
+
33
+
34
+ def linear_beta_schedule(timesteps, max_beta=0.02):
35
+ """
36
+ linear schedule
37
+ """
38
+ betas = np.linspace(1e-4, max_beta, timesteps)
39
+ return betas
40
+
41
+
42
+ def cosine_beta_schedule(timesteps, s=0.008):
43
+ """
44
+ cosine schedule
45
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
46
+ """
47
+ steps = timesteps + 1
48
+ x = np.linspace(0, steps, steps)
49
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
50
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
51
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
52
+ return np.clip(betas, a_min=0, a_max=0.999)
53
+
54
+
55
+ beta_schedule = {
56
+ "cosine": cosine_beta_schedule,
57
+ "linear": linear_beta_schedule,
58
+ }
59
+
60
+
61
+ class GaussianDiffusion(nn.Module):
62
+ def __init__(self,
63
+ denoise_fn,
64
+ out_dims=128,
65
+ timesteps=1000,
66
+ k_step=1000,
67
+ max_beta=0.02,
68
+ spec_min=-12,
69
+ spec_max=2):
70
+ super().__init__()
71
+ self.denoise_fn = denoise_fn
72
+ self.out_dims = out_dims
73
+ betas = beta_schedule['linear'](timesteps, max_beta=max_beta)
74
+
75
+ alphas = 1. - betas
76
+ alphas_cumprod = np.cumprod(alphas, axis=0)
77
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
78
+
79
+ timesteps, = betas.shape
80
+ self.num_timesteps = int(timesteps)
81
+ self.k_step = k_step
82
+
83
+ self.noise_list = deque(maxlen=4)
84
+
85
+ to_torch = partial(torch.tensor, dtype=torch.float32)
86
+
87
+ self.register_buffer('betas', to_torch(betas))
88
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
89
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
90
+
91
+ # calculations for diffusion q(x_t | x_{t-1}) and others
92
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
93
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
94
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
95
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
96
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
97
+
98
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
99
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
100
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
101
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
102
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
103
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
104
+ self.register_buffer('posterior_mean_coef1', to_torch(
105
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
106
+ self.register_buffer('posterior_mean_coef2', to_torch(
107
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
108
+
109
+ self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims])
110
+ self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims])
111
+
112
+ def q_mean_variance(self, x_start, t):
113
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
114
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
115
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
116
+ return mean, variance, log_variance
117
+
118
+ def predict_start_from_noise(self, x_t, t, noise):
119
+ return (
120
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
121
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
122
+ )
123
+
124
+ def q_posterior(self, x_start, x_t, t):
125
+ posterior_mean = (
126
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
127
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
128
+ )
129
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
130
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
131
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
132
+
133
+ def p_mean_variance(self, x, t, cond):
134
+ noise_pred = self.denoise_fn(x, t, cond=cond)
135
+ x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
136
+
137
+ x_recon.clamp_(-1., 1.)
138
+
139
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
140
+ return model_mean, posterior_variance, posterior_log_variance
141
+
142
+ @torch.no_grad()
143
+ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
144
+ b, *_, device = *x.shape, x.device
145
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
146
+ noise = noise_like(x.shape, device, repeat_noise)
147
+ # no noise when t == 0
148
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
149
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
150
+
151
+ @torch.no_grad()
152
+ def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
153
+ """
154
+ Use the PLMS method from
155
+ [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
156
+ """
157
+
158
+ def get_x_pred(x, noise_t, t):
159
+ a_t = extract(self.alphas_cumprod, t, x.shape)
160
+ a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
161
+ a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
162
+
163
+ x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (
164
+ a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
165
+ x_pred = x + x_delta
166
+
167
+ return x_pred
168
+
169
+ noise_list = self.noise_list
170
+ noise_pred = self.denoise_fn(x, t, cond=cond)
171
+
172
+ if len(noise_list) == 0:
173
+ x_pred = get_x_pred(x, noise_pred, t)
174
+ noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
175
+ noise_pred_prime = (noise_pred + noise_pred_prev) / 2
176
+ elif len(noise_list) == 1:
177
+ noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
178
+ elif len(noise_list) == 2:
179
+ noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
180
+ else:
181
+ noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
182
+
183
+ x_prev = get_x_pred(x, noise_pred_prime, t)
184
+ noise_list.append(noise_pred)
185
+
186
+ return x_prev
187
+
188
+ def q_sample(self, x_start, t, noise=None):
189
+ noise = default(noise, lambda: torch.randn_like(x_start))
190
+ return (
191
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
192
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
193
+ )
194
+
195
+ def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'):
196
+ noise = default(noise, lambda: torch.randn_like(x_start))
197
+
198
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
199
+ x_recon = self.denoise_fn(x_noisy, t, cond)
200
+
201
+ if loss_type == 'l1':
202
+ loss = (noise - x_recon).abs().mean()
203
+ elif loss_type == 'l2':
204
+ loss = F.mse_loss(noise, x_recon)
205
+ else:
206
+ raise NotImplementedError()
207
+
208
+ return loss
209
+
210
+ def forward(self,
211
+ condition,
212
+ gt_spec=None,
213
+ infer=True,
214
+ infer_speedup=10,
215
+ method='dpm-solver',
216
+ k_step=300,
217
+ use_tqdm=True):
218
+ """
219
+ conditioning diffusion, use fastspeech2 encoder output as the condition
220
+ """
221
+ cond = condition.transpose(1, 2)
222
+ b, device = condition.shape[0], condition.device
223
+
224
+ if not infer:
225
+ spec = self.norm_spec(gt_spec)
226
+ t = torch.randint(0, self.k_step, (b,), device=device).long()
227
+ norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
228
+ return self.p_losses(norm_spec, t, cond=cond)
229
+ else:
230
+ shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])
231
+
232
+ if gt_spec is None:
233
+ t = self.k_step
234
+ x = torch.randn(shape, device=device)
235
+ else:
236
+ t = k_step
237
+ norm_spec = self.norm_spec(gt_spec)
238
+ norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
239
+ x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())
240
+
241
+ if method is not None and infer_speedup > 1:
242
+ if method == 'dpm-solver':
243
+ from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
244
+ # 1. Define the noise schedule.
245
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])
246
+
247
+ # 2. Convert your discrete-time `model` to the continuous-time
248
+ # noise prediction model. Here is an example for a diffusion model
249
+ # `model` with the noise prediction type ("noise") .
250
+ def my_wrapper(fn):
251
+ def wrapped(x, t, **kwargs):
252
+ ret = fn(x, t, **kwargs)
253
+ if use_tqdm:
254
+ self.bar.update(1)
255
+ return ret
256
+
257
+ return wrapped
258
+
259
+ model_fn = model_wrapper(
260
+ my_wrapper(self.denoise_fn),
261
+ noise_schedule,
262
+ model_type="noise", # or "x_start" or "v" or "score"
263
+ model_kwargs={"cond": cond}
264
+ )
265
+
266
+ # 3. Define dpm-solver and sample by singlestep DPM-Solver.
267
+ # (We recommend singlestep DPM-Solver for unconditional sampling)
268
+ # You can adjust the `steps` to balance the computation
269
+ # costs and the sample quality.
270
+ dpm_solver = DPM_Solver(model_fn, noise_schedule)
271
+
272
+ steps = t // infer_speedup
273
+ if use_tqdm:
274
+ self.bar = tqdm(desc="sample time step", total=steps)
275
+ x = dpm_solver.sample(
276
+ x,
277
+ steps=steps,
278
+ order=3,
279
+ skip_type="time_uniform",
280
+ method="singlestep",
281
+ )
282
+ if use_tqdm:
283
+ self.bar.close()
284
+ elif method == 'pndm':
285
+ self.noise_list = deque(maxlen=4)
286
+ if use_tqdm:
287
+ for i in tqdm(
288
+ reversed(range(0, t, infer_speedup)), desc='sample time step',
289
+ total=t // infer_speedup,
290
+ ):
291
+ x = self.p_sample_plms(
292
+ x, torch.full((b,), i, device=device, dtype=torch.long),
293
+ infer_speedup, cond=cond
294
+ )
295
+ else:
296
+ for i in reversed(range(0, t, infer_speedup)):
297
+ x = self.p_sample_plms(
298
+ x, torch.full((b,), i, device=device, dtype=torch.long),
299
+ infer_speedup, cond=cond
300
+ )
301
+ else:
302
+ raise NotImplementedError(method)
303
+ else:
304
+ if use_tqdm:
305
+ for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
306
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
307
+ else:
308
+ for i in reversed(range(0, t)):
309
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
310
+ x = x.squeeze(1).transpose(1, 2) # [B, T, M]
311
+ return self.denorm_spec(x)
312
+
313
+ def norm_spec(self, x):
314
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
315
+
316
+ def denorm_spec(self, x):
317
+ return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
DDSP-SVC/diffusion/dpm_solver_pytorch.py ADDED
@@ -0,0 +1,1201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule='discrete',
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
52
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
63
+ schedule are the default settings in DDPM and improved-DDPM:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
69
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
70
+ T: A `float` number. The ending time of the forward process.
71
+
72
+ ===============================================================
73
+
74
+ Args:
75
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
76
+ 'linear' or 'cosine' for continuous-time DPMs.
77
+ Returns:
78
+ A wrapper object of the forward SDE (VP type).
79
+
80
+ ===============================================================
81
+
82
+ Example:
83
+
84
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
86
+
87
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
88
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
89
+
90
+ # For continuous-time DPMs (VPSDE), linear schedule:
91
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
92
+
93
+ """
94
+
95
+ if schedule not in ['discrete', 'linear', 'cosine']:
96
+ raise ValueError(
97
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
98
+ schedule))
99
+
100
+ self.schedule = schedule
101
+ if schedule == 'discrete':
102
+ if betas is not None:
103
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
104
+ else:
105
+ assert alphas_cumprod is not None
106
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
107
+ self.total_N = len(log_alphas)
108
+ self.T = 1.
109
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
110
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
111
+ else:
112
+ self.total_N = 1000
113
+ self.beta_0 = continuous_beta_0
114
+ self.beta_1 = continuous_beta_1
115
+ self.cosine_s = 0.008
116
+ self.cosine_beta_max = 999.
117
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
118
+ 1. + self.cosine_s) / math.pi - self.cosine_s
119
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
120
+ self.schedule = schedule
121
+ if schedule == 'cosine':
122
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
123
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
124
+ self.T = 0.9946
125
+ else:
126
+ self.T = 1.
127
+
128
+ def marginal_log_mean_coeff(self, t):
129
+ """
130
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
131
+ """
132
+ if self.schedule == 'discrete':
133
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
134
+ self.log_alpha_array.to(t.device)).reshape((-1))
135
+ elif self.schedule == 'linear':
136
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
137
+ elif self.schedule == 'cosine':
138
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
139
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
140
+ return log_alpha_t
141
+
142
+ def marginal_alpha(self, t):
143
+ """
144
+ Compute alpha_t of a given continuous-time label t in [0, T].
145
+ """
146
+ return torch.exp(self.marginal_log_mean_coeff(t))
147
+
148
+ def marginal_std(self, t):
149
+ """
150
+ Compute sigma_t of a given continuous-time label t in [0, T].
151
+ """
152
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
153
+
154
+ def marginal_lambda(self, t):
155
+ """
156
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
157
+ """
158
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
159
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
160
+ return log_mean_coeff - log_std
161
+
162
+ def inverse_lambda(self, lamb):
163
+ """
164
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
165
+ """
166
+ if self.schedule == 'linear':
167
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
168
+ Delta = self.beta_0 ** 2 + tmp
169
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
170
+ elif self.schedule == 'discrete':
171
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
172
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
173
+ torch.flip(self.t_array.to(lamb.device), [1]))
174
+ return t.reshape((-1,))
175
+ else:
176
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
177
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
178
+ 1. + self.cosine_s) / math.pi - self.cosine_s
179
+ t = t_fn(log_alpha)
180
+ return t
181
+
182
+
183
+ def model_wrapper(
184
+ model,
185
+ noise_schedule,
186
+ model_type="noise",
187
+ model_kwargs={},
188
+ guidance_type="uncond",
189
+ condition=None,
190
+ unconditional_condition=None,
191
+ guidance_scale=1.,
192
+ classifier_fn=None,
193
+ classifier_kwargs={},
194
+ ):
195
+ """Create a wrapper function for the noise prediction model.
196
+
197
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
198
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
199
+
200
+ We support four types of the diffusion model by setting `model_type`:
201
+
202
+ 1. "noise": noise prediction model. (Trained by predicting noise).
203
+
204
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
205
+
206
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
207
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
208
+
209
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
210
+ arXiv preprint arXiv:2202.00512 (2022).
211
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
212
+ arXiv preprint arXiv:2210.02303 (2022).
213
+
214
+ 4. "score": marginal score function. (Trained by denoising score matching).
215
+ Note that the score function and the noise prediction model follows a simple relationship:
216
+ ```
217
+ noise(x_t, t) = -sigma_t * score(x_t, t)
218
+ ```
219
+
220
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
221
+ 1. "uncond": unconditional sampling by DPMs.
222
+ The input `model` has the following format:
223
+ ``
224
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
225
+ ``
226
+
227
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
228
+ The input `model` has the following format:
229
+ ``
230
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
231
+ ``
232
+
233
+ The input `classifier_fn` has the following format:
234
+ ``
235
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
236
+ ``
237
+
238
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
239
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
240
+
241
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
242
+ The input `model` has the following format:
243
+ ``
244
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
245
+ ``
246
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
247
+
248
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
249
+ arXiv preprint arXiv:2207.12598 (2022).
250
+
251
+
252
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
253
+ or continuous-time labels (i.e. epsilon to T).
254
+
255
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
256
+ ``
257
+ def model_fn(x, t_continuous) -> noise:
258
+ t_input = get_model_input_time(t_continuous)
259
+ return noise_pred(model, x, t_input, **model_kwargs)
260
+ ``
261
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
262
+
263
+ ===============================================================
264
+
265
+ Args:
266
+ model: A diffusion model with the corresponding format described above.
267
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
268
+ model_type: A `str`. The parameterization type of the diffusion model.
269
+ "noise" or "x_start" or "v" or "score".
270
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
271
+ guidance_type: A `str`. The type of the guidance for sampling.
272
+ "uncond" or "classifier" or "classifier-free".
273
+ condition: A pytorch tensor. The condition for the guided sampling.
274
+ Only used for "classifier" or "classifier-free" guidance type.
275
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
276
+ Only used for "classifier-free" guidance type.
277
+ guidance_scale: A `float`. The scale for the guided sampling.
278
+ classifier_fn: A classifier function. Only used for the classifier guidance.
279
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
280
+ Returns:
281
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
282
+ """
283
+
284
+ def get_model_input_time(t_continuous):
285
+ """
286
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
287
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
288
+ For continuous-time DPMs, we just use `t_continuous`.
289
+ """
290
+ if noise_schedule.schedule == 'discrete':
291
+ return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
292
+ else:
293
+ return t_continuous
294
+
295
+ def noise_pred_fn(x, t_continuous, cond=None):
296
+ if t_continuous.reshape((-1,)).shape[0] == 1:
297
+ t_continuous = t_continuous.expand((x.shape[0]))
298
+ t_input = get_model_input_time(t_continuous)
299
+ if cond is None:
300
+ output = model(x, t_input, **model_kwargs)
301
+ else:
302
+ output = model(x, t_input, cond, **model_kwargs)
303
+ if model_type == "noise":
304
+ return output
305
+ elif model_type == "x_start":
306
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
307
+ dims = x.dim()
308
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
309
+ elif model_type == "v":
310
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
311
+ dims = x.dim()
312
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
313
+ elif model_type == "score":
314
+ sigma_t = noise_schedule.marginal_std(t_continuous)
315
+ dims = x.dim()
316
+ return -expand_dims(sigma_t, dims) * output
317
+
318
+ def cond_grad_fn(x, t_input):
319
+ """
320
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
321
+ """
322
+ with torch.enable_grad():
323
+ x_in = x.detach().requires_grad_(True)
324
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
325
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
326
+
327
+ def model_fn(x, t_continuous):
328
+ """
329
+ The noise predicition model function that is used for DPM-Solver.
330
+ """
331
+ if t_continuous.reshape((-1,)).shape[0] == 1:
332
+ t_continuous = t_continuous.expand((x.shape[0]))
333
+ if guidance_type == "uncond":
334
+ return noise_pred_fn(x, t_continuous)
335
+ elif guidance_type == "classifier":
336
+ assert classifier_fn is not None
337
+ t_input = get_model_input_time(t_continuous)
338
+ cond_grad = cond_grad_fn(x, t_input)
339
+ sigma_t = noise_schedule.marginal_std(t_continuous)
340
+ noise = noise_pred_fn(x, t_continuous)
341
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
342
+ elif guidance_type == "classifier-free":
343
+ if guidance_scale == 1. or unconditional_condition is None:
344
+ return noise_pred_fn(x, t_continuous, cond=condition)
345
+ else:
346
+ x_in = torch.cat([x] * 2)
347
+ t_in = torch.cat([t_continuous] * 2)
348
+ c_in = torch.cat([unconditional_condition, condition])
349
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
350
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
351
+
352
+ assert model_type in ["noise", "x_start", "v"]
353
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
354
+ return model_fn
355
+
356
+
357
+ class DPM_Solver:
358
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
359
+ """Construct a DPM-Solver.
360
+
361
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
362
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
363
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
364
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
365
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
366
+
367
+ Args:
368
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
369
+ ``
370
+ def model_fn(x, t_continuous):
371
+ return noise
372
+ ``
373
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
374
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
375
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
376
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
377
+
378
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
379
+ """
380
+ self.model = model_fn
381
+ self.noise_schedule = noise_schedule
382
+ self.predict_x0 = predict_x0
383
+ self.thresholding = thresholding
384
+ self.max_val = max_val
385
+
386
+ def noise_prediction_fn(self, x, t):
387
+ """
388
+ Return the noise prediction model.
389
+ """
390
+ return self.model(x, t)
391
+
392
+ def data_prediction_fn(self, x, t):
393
+ """
394
+ Return the data prediction model (with thresholding).
395
+ """
396
+ noise = self.noise_prediction_fn(x, t)
397
+ dims = x.dim()
398
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
399
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
400
+ if self.thresholding:
401
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
402
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
403
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
404
+ x0 = torch.clamp(x0, -s, s) / s
405
+ return x0
406
+
407
+ def model_fn(self, x, t):
408
+ """
409
+ Convert the model to the noise prediction model or the data prediction model.
410
+ """
411
+ if self.predict_x0:
412
+ return self.data_prediction_fn(x, t)
413
+ else:
414
+ return self.noise_prediction_fn(x, t)
415
+
416
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
417
+ """Compute the intermediate time steps for sampling.
418
+
419
+ Args:
420
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
421
+ - 'logSNR': uniform logSNR for the time steps.
422
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
423
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
424
+ t_T: A `float`. The starting time of the sampling (default is T).
425
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
426
+ N: A `int`. The total number of the spacing of the time steps.
427
+ device: A torch device.
428
+ Returns:
429
+ A pytorch tensor of the time steps, with the shape (N + 1,).
430
+ """
431
+ if skip_type == 'logSNR':
432
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
433
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
434
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
435
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
436
+ elif skip_type == 'time_uniform':
437
+ return torch.linspace(t_T, t_0, N + 1).to(device)
438
+ elif skip_type == 'time_quadratic':
439
+ t_order = 2
440
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
441
+ return t
442
+ else:
443
+ raise ValueError(
444
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
445
+
446
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
447
+ """
448
+ Get the order of each step for sampling by the singlestep DPM-Solver.
449
+
450
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
451
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
452
+ - If order == 1:
453
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
454
+ - If order == 2:
455
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
456
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
457
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
458
+ - If order == 3:
459
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
460
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
461
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
462
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
463
+
464
+ ============================================
465
+ Args:
466
+ order: A `int`. The max order for the solver (2 or 3).
467
+ steps: A `int`. The total number of function evaluations (NFE).
468
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
469
+ - 'logSNR': uniform logSNR for the time steps.
470
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
471
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
472
+ t_T: A `float`. The starting time of the sampling (default is T).
473
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
474
+ device: A torch device.
475
+ Returns:
476
+ orders: A list of the solver order of each step.
477
+ """
478
+ if order == 3:
479
+ K = steps // 3 + 1
480
+ if steps % 3 == 0:
481
+ orders = [3, ] * (K - 2) + [2, 1]
482
+ elif steps % 3 == 1:
483
+ orders = [3, ] * (K - 1) + [1]
484
+ else:
485
+ orders = [3, ] * (K - 1) + [2]
486
+ elif order == 2:
487
+ if steps % 2 == 0:
488
+ K = steps // 2
489
+ orders = [2, ] * K
490
+ else:
491
+ K = steps // 2 + 1
492
+ orders = [2, ] * (K - 1) + [1]
493
+ elif order == 1:
494
+ K = 1
495
+ orders = [1, ] * steps
496
+ else:
497
+ raise ValueError("'order' must be '1' or '2' or '3'.")
498
+ if skip_type == 'logSNR':
499
+ # To reproduce the results in DPM-Solver paper
500
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
501
+ else:
502
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
503
+ torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
504
+ return timesteps_outer, orders
505
+
506
+ def denoise_fn(self, x, s):
507
+ """
508
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
509
+ """
510
+ return self.data_prediction_fn(x, s)
511
+
512
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
513
+ """
514
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
515
+
516
+ Args:
517
+ x: A pytorch tensor. The initial value at time `s`.
518
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
519
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
520
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
521
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
522
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
523
+ Returns:
524
+ x_t: A pytorch tensor. The approximated solution at time `t`.
525
+ """
526
+ ns = self.noise_schedule
527
+ dims = x.dim()
528
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
529
+ h = lambda_t - lambda_s
530
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
531
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
532
+ alpha_t = torch.exp(log_alpha_t)
533
+
534
+ if self.predict_x0:
535
+ phi_1 = torch.expm1(-h)
536
+ if model_s is None:
537
+ model_s = self.model_fn(x, s)
538
+ x_t = (
539
+ expand_dims(sigma_t / sigma_s, dims) * x
540
+ - expand_dims(alpha_t * phi_1, dims) * model_s
541
+ )
542
+ if return_intermediate:
543
+ return x_t, {'model_s': model_s}
544
+ else:
545
+ return x_t
546
+ else:
547
+ phi_1 = torch.expm1(h)
548
+ if model_s is None:
549
+ model_s = self.model_fn(x, s)
550
+ x_t = (
551
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
552
+ - expand_dims(sigma_t * phi_1, dims) * model_s
553
+ )
554
+ if return_intermediate:
555
+ return x_t, {'model_s': model_s}
556
+ else:
557
+ return x_t
558
+
559
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
560
+ solver_type='dpm_solver'):
561
+ """
562
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
563
+
564
+ Args:
565
+ x: A pytorch tensor. The initial value at time `s`.
566
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
567
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
568
+ r1: A `float`. The hyperparameter of the second-order solver.
569
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
570
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
571
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
572
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
573
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
574
+ Returns:
575
+ x_t: A pytorch tensor. The approximated solution at time `t`.
576
+ """
577
+ if solver_type not in ['dpm_solver', 'taylor']:
578
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
579
+ if r1 is None:
580
+ r1 = 0.5
581
+ ns = self.noise_schedule
582
+ dims = x.dim()
583
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
584
+ h = lambda_t - lambda_s
585
+ lambda_s1 = lambda_s + r1 * h
586
+ s1 = ns.inverse_lambda(lambda_s1)
587
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
588
+ s1), ns.marginal_log_mean_coeff(t)
589
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
590
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
591
+
592
+ if self.predict_x0:
593
+ phi_11 = torch.expm1(-r1 * h)
594
+ phi_1 = torch.expm1(-h)
595
+
596
+ if model_s is None:
597
+ model_s = self.model_fn(x, s)
598
+ x_s1 = (
599
+ expand_dims(sigma_s1 / sigma_s, dims) * x
600
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
601
+ )
602
+ model_s1 = self.model_fn(x_s1, s1)
603
+ if solver_type == 'dpm_solver':
604
+ x_t = (
605
+ expand_dims(sigma_t / sigma_s, dims) * x
606
+ - expand_dims(alpha_t * phi_1, dims) * model_s
607
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
608
+ )
609
+ elif solver_type == 'taylor':
610
+ x_t = (
611
+ expand_dims(sigma_t / sigma_s, dims) * x
612
+ - expand_dims(alpha_t * phi_1, dims) * model_s
613
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
614
+ model_s1 - model_s)
615
+ )
616
+ else:
617
+ phi_11 = torch.expm1(r1 * h)
618
+ phi_1 = torch.expm1(h)
619
+
620
+ if model_s is None:
621
+ model_s = self.model_fn(x, s)
622
+ x_s1 = (
623
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
624
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
625
+ )
626
+ model_s1 = self.model_fn(x_s1, s1)
627
+ if solver_type == 'dpm_solver':
628
+ x_t = (
629
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
630
+ - expand_dims(sigma_t * phi_1, dims) * model_s
631
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
632
+ )
633
+ elif solver_type == 'taylor':
634
+ x_t = (
635
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
636
+ - expand_dims(sigma_t * phi_1, dims) * model_s
637
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
638
+ )
639
+ if return_intermediate:
640
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
641
+ else:
642
+ return x_t
643
+
644
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
645
+ return_intermediate=False, solver_type='dpm_solver'):
646
+ """
647
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
648
+
649
+ Args:
650
+ x: A pytorch tensor. The initial value at time `s`.
651
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
652
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
653
+ r1: A `float`. The hyperparameter of the third-order solver.
654
+ r2: A `float`. The hyperparameter of the third-order solver.
655
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
656
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
657
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
658
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
659
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
660
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
661
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
662
+ Returns:
663
+ x_t: A pytorch tensor. The approximated solution at time `t`.
664
+ """
665
+ if solver_type not in ['dpm_solver', 'taylor']:
666
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
667
+ if r1 is None:
668
+ r1 = 1. / 3.
669
+ if r2 is None:
670
+ r2 = 2. / 3.
671
+ ns = self.noise_schedule
672
+ dims = x.dim()
673
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
674
+ h = lambda_t - lambda_s
675
+ lambda_s1 = lambda_s + r1 * h
676
+ lambda_s2 = lambda_s + r2 * h
677
+ s1 = ns.inverse_lambda(lambda_s1)
678
+ s2 = ns.inverse_lambda(lambda_s2)
679
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
680
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
681
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
682
+ s2), ns.marginal_std(t)
683
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
684
+
685
+ if self.predict_x0:
686
+ phi_11 = torch.expm1(-r1 * h)
687
+ phi_12 = torch.expm1(-r2 * h)
688
+ phi_1 = torch.expm1(-h)
689
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
690
+ phi_2 = phi_1 / h + 1.
691
+ phi_3 = phi_2 / h - 0.5
692
+
693
+ if model_s is None:
694
+ model_s = self.model_fn(x, s)
695
+ if model_s1 is None:
696
+ x_s1 = (
697
+ expand_dims(sigma_s1 / sigma_s, dims) * x
698
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
699
+ )
700
+ model_s1 = self.model_fn(x_s1, s1)
701
+ x_s2 = (
702
+ expand_dims(sigma_s2 / sigma_s, dims) * x
703
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
704
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
705
+ )
706
+ model_s2 = self.model_fn(x_s2, s2)
707
+ if solver_type == 'dpm_solver':
708
+ x_t = (
709
+ expand_dims(sigma_t / sigma_s, dims) * x
710
+ - expand_dims(alpha_t * phi_1, dims) * model_s
711
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
712
+ )
713
+ elif solver_type == 'taylor':
714
+ D1_0 = (1. / r1) * (model_s1 - model_s)
715
+ D1_1 = (1. / r2) * (model_s2 - model_s)
716
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
717
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
718
+ x_t = (
719
+ expand_dims(sigma_t / sigma_s, dims) * x
720
+ - expand_dims(alpha_t * phi_1, dims) * model_s
721
+ + expand_dims(alpha_t * phi_2, dims) * D1
722
+ - expand_dims(alpha_t * phi_3, dims) * D2
723
+ )
724
+ else:
725
+ phi_11 = torch.expm1(r1 * h)
726
+ phi_12 = torch.expm1(r2 * h)
727
+ phi_1 = torch.expm1(h)
728
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
729
+ phi_2 = phi_1 / h - 1.
730
+ phi_3 = phi_2 / h - 0.5
731
+
732
+ if model_s is None:
733
+ model_s = self.model_fn(x, s)
734
+ if model_s1 is None:
735
+ x_s1 = (
736
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
737
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
738
+ )
739
+ model_s1 = self.model_fn(x_s1, s1)
740
+ x_s2 = (
741
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
742
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
743
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
744
+ )
745
+ model_s2 = self.model_fn(x_s2, s2)
746
+ if solver_type == 'dpm_solver':
747
+ x_t = (
748
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
749
+ - expand_dims(sigma_t * phi_1, dims) * model_s
750
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
751
+ )
752
+ elif solver_type == 'taylor':
753
+ D1_0 = (1. / r1) * (model_s1 - model_s)
754
+ D1_1 = (1. / r2) * (model_s2 - model_s)
755
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
756
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
757
+ x_t = (
758
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
759
+ - expand_dims(sigma_t * phi_1, dims) * model_s
760
+ - expand_dims(sigma_t * phi_2, dims) * D1
761
+ - expand_dims(sigma_t * phi_3, dims) * D2
762
+ )
763
+
764
+ if return_intermediate:
765
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
766
+ else:
767
+ return x_t
768
+
769
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
770
+ """
771
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
772
+
773
+ Args:
774
+ x: A pytorch tensor. The initial value at time `s`.
775
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
776
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
777
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
778
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
779
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
780
+ Returns:
781
+ x_t: A pytorch tensor. The approximated solution at time `t`.
782
+ """
783
+ if solver_type not in ['dpm_solver', 'taylor']:
784
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
785
+ ns = self.noise_schedule
786
+ dims = x.dim()
787
+ model_prev_1, model_prev_0 = model_prev_list
788
+ t_prev_1, t_prev_0 = t_prev_list
789
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
790
+ t_prev_0), ns.marginal_lambda(t)
791
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
792
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
793
+ alpha_t = torch.exp(log_alpha_t)
794
+
795
+ h_0 = lambda_prev_0 - lambda_prev_1
796
+ h = lambda_t - lambda_prev_0
797
+ r0 = h_0 / h
798
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
799
+ if self.predict_x0:
800
+ if solver_type == 'dpm_solver':
801
+ x_t = (
802
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
803
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
804
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
805
+ )
806
+ elif solver_type == 'taylor':
807
+ x_t = (
808
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
809
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
810
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
811
+ )
812
+ else:
813
+ if solver_type == 'dpm_solver':
814
+ x_t = (
815
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
816
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
817
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
818
+ )
819
+ elif solver_type == 'taylor':
820
+ x_t = (
821
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
822
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
823
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
824
+ )
825
+ return x_t
826
+
827
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
828
+ """
829
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
830
+
831
+ Args:
832
+ x: A pytorch tensor. The initial value at time `s`.
833
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
834
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
835
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
836
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
837
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
838
+ Returns:
839
+ x_t: A pytorch tensor. The approximated solution at time `t`.
840
+ """
841
+ ns = self.noise_schedule
842
+ dims = x.dim()
843
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
844
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
845
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
846
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
847
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
848
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
849
+ alpha_t = torch.exp(log_alpha_t)
850
+
851
+ h_1 = lambda_prev_1 - lambda_prev_2
852
+ h_0 = lambda_prev_0 - lambda_prev_1
853
+ h = lambda_t - lambda_prev_0
854
+ r0, r1 = h_0 / h, h_1 / h
855
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
856
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
857
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
858
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
859
+ if self.predict_x0:
860
+ x_t = (
861
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
862
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
863
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
864
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
865
+ )
866
+ else:
867
+ x_t = (
868
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
869
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
870
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
871
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
872
+ )
873
+ return x_t
874
+
875
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
876
+ r2=None):
877
+ """
878
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
879
+
880
+ Args:
881
+ x: A pytorch tensor. The initial value at time `s`.
882
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
883
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
884
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
885
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
886
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
887
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
888
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
889
+ r2: A `float`. The hyperparameter of the third-order solver.
890
+ Returns:
891
+ x_t: A pytorch tensor. The approximated solution at time `t`.
892
+ """
893
+ if order == 1:
894
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
895
+ elif order == 2:
896
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
897
+ solver_type=solver_type, r1=r1)
898
+ elif order == 3:
899
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
900
+ solver_type=solver_type, r1=r1, r2=r2)
901
+ else:
902
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
903
+
904
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
905
+ """
906
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
907
+
908
+ Args:
909
+ x: A pytorch tensor. The initial value at time `s`.
910
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
911
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
912
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
913
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
914
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
915
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
916
+ Returns:
917
+ x_t: A pytorch tensor. The approximated solution at time `t`.
918
+ """
919
+ if order == 1:
920
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
921
+ elif order == 2:
922
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
923
+ elif order == 3:
924
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
925
+ else:
926
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
927
+
928
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
929
+ solver_type='dpm_solver'):
930
+ """
931
+ The adaptive step size solver based on singlestep DPM-Solver.
932
+
933
+ Args:
934
+ x: A pytorch tensor. The initial value at time `t_T`.
935
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
936
+ t_T: A `float`. The starting time of the sampling (default is T).
937
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
938
+ h_init: A `float`. The initial step size (for logSNR).
939
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
940
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
941
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
942
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
943
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
944
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
945
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
946
+ Returns:
947
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
948
+
949
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
950
+ """
951
+ ns = self.noise_schedule
952
+ s = t_T * torch.ones((x.shape[0],)).to(x)
953
+ lambda_s = ns.marginal_lambda(s)
954
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
955
+ h = h_init * torch.ones_like(s).to(x)
956
+ x_prev = x
957
+ nfe = 0
958
+ if order == 2:
959
+ r1 = 0.5
960
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
961
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
962
+ solver_type=solver_type,
963
+ **kwargs)
964
+ elif order == 3:
965
+ r1, r2 = 1. / 3., 2. / 3.
966
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
967
+ return_intermediate=True,
968
+ solver_type=solver_type)
969
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
970
+ solver_type=solver_type,
971
+ **kwargs)
972
+ else:
973
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
974
+ while torch.abs((s - t_0)).mean() > t_err:
975
+ t = ns.inverse_lambda(lambda_s + h)
976
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
977
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
978
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
979
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
980
+ E = norm_fn((x_higher - x_lower) / delta).max()
981
+ if torch.all(E <= 1.):
982
+ x = x_higher
983
+ s = t
984
+ x_prev = x_lower
985
+ lambda_s = ns.marginal_lambda(s)
986
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
987
+ nfe += order
988
+ print('adaptive solver nfe', nfe)
989
+ return x
990
+
991
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
992
+ method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078,
993
+ rtol=0.05,
994
+ ):
995
+ """
996
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
997
+
998
+ =====================================================
999
+
1000
+ We support the following algorithms for both noise prediction model and data prediction model:
1001
+ - 'singlestep':
1002
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1003
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1004
+ The total number of function evaluations (NFE) == `steps`.
1005
+ Given a fixed NFE == `steps`, the sampling procedure is:
1006
+ - If `order` == 1:
1007
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1008
+ - If `order` == 2:
1009
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1010
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1011
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1012
+ - If `order` == 3:
1013
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1014
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1015
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1016
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1017
+ - 'multistep':
1018
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1019
+ We initialize the first `order` values by lower order multistep solvers.
1020
+ Given a fixed NFE == `steps`, the sampling procedure is:
1021
+ Denote K = steps.
1022
+ - If `order` == 1:
1023
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1024
+ - If `order` == 2:
1025
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
1026
+ - If `order` == 3:
1027
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
1028
+ - 'singlestep_fixed':
1029
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1030
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1031
+ - 'adaptive':
1032
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1033
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1034
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
1035
+ (NFE) and the sample quality.
1036
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1037
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1038
+
1039
+ =====================================================
1040
+
1041
+ Some advices for choosing the algorithm:
1042
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1043
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1044
+ e.g.
1045
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1046
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1047
+ skip_type='time_uniform', method='singlestep')
1048
+ - For **guided sampling with large guidance scale** by DPMs:
1049
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1050
+ e.g.
1051
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1052
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1053
+ skip_type='time_uniform', method='multistep')
1054
+
1055
+ We support three types of `skip_type`:
1056
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1057
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1058
+ - 'time_quadratic': quadratic time for the time steps.
1059
+
1060
+ =====================================================
1061
+ Args:
1062
+ x: A pytorch tensor. The initial value at time `t_start`
1063
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1064
+ steps: A `int`. The total number of function evaluations (NFE).
1065
+ t_start: A `float`. The starting time of the sampling.
1066
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1067
+ t_end: A `float`. The ending time of the sampling.
1068
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1069
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1070
+ For discrete-time DPMs:
1071
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1072
+ For continuous-time DPMs:
1073
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1074
+ order: A `int`. The order of DPM-Solver.
1075
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1076
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1077
+ denoise: A `bool`. Whether to denoise at the final step. Default is False.
1078
+ If `denoise` is True, the total NFE is (`steps` + 1).
1079
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1080
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1081
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1082
+ Returns:
1083
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1084
+
1085
+ """
1086
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1087
+ t_T = self.noise_schedule.T if t_start is None else t_start
1088
+ device = x.device
1089
+ if method == 'adaptive':
1090
+ with torch.no_grad():
1091
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1092
+ solver_type=solver_type)
1093
+ elif method == 'multistep':
1094
+ assert steps >= order
1095
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1096
+ assert timesteps.shape[0] - 1 == steps
1097
+ with torch.no_grad():
1098
+ vec_t = timesteps[0].expand((x.shape[0]))
1099
+ model_prev_list = [self.model_fn(x, vec_t)]
1100
+ t_prev_list = [vec_t]
1101
+ # Init the first `order` values by lower order multistep DPM-Solver.
1102
+ for init_order in range(1, order):
1103
+ vec_t = timesteps[init_order].expand(x.shape[0])
1104
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1105
+ solver_type=solver_type)
1106
+ model_prev_list.append(self.model_fn(x, vec_t))
1107
+ t_prev_list.append(vec_t)
1108
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1109
+ for step in range(order, steps + 1):
1110
+ vec_t = timesteps[step].expand(x.shape[0])
1111
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order,
1112
+ solver_type=solver_type)
1113
+ for i in range(order - 1):
1114
+ t_prev_list[i] = t_prev_list[i + 1]
1115
+ model_prev_list[i] = model_prev_list[i + 1]
1116
+ t_prev_list[-1] = vec_t
1117
+ # We do not need to evaluate the final model value.
1118
+ if step < steps:
1119
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1120
+ elif method in ['singlestep', 'singlestep_fixed']:
1121
+ if method == 'singlestep':
1122
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1123
+ skip_type=skip_type,
1124
+ t_T=t_T, t_0=t_0,
1125
+ device=device)
1126
+ elif method == 'singlestep_fixed':
1127
+ K = steps // order
1128
+ orders = [order, ] * K
1129
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1130
+ for i, order in enumerate(orders):
1131
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1132
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1133
+ N=order, device=device)
1134
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1135
+ vec_s, vec_t = t_T_inner.repeat(x.shape[0]), t_0_inner.repeat(x.shape[0])
1136
+ h = lambda_inner[-1] - lambda_inner[0]
1137
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1138
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1139
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1140
+ if denoise:
1141
+ x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1142
+ return x
1143
+
1144
+
1145
+ #############################################################
1146
+ # other utility functions
1147
+ #############################################################
1148
+
1149
+ def interpolate_fn(x, xp, yp):
1150
+ """
1151
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1152
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1153
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1154
+
1155
+ Args:
1156
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1157
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1158
+ yp: PyTorch tensor with shape [C, K].
1159
+ Returns:
1160
+ The function values f(x), with shape [N, C].
1161
+ """
1162
+ N, K = x.shape[0], xp.shape[1]
1163
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1164
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1165
+ x_idx = torch.argmin(x_indices, dim=2)
1166
+ cand_start_idx = x_idx - 1
1167
+ start_idx = torch.where(
1168
+ torch.eq(x_idx, 0),
1169
+ torch.tensor(1, device=x.device),
1170
+ torch.where(
1171
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1172
+ ),
1173
+ )
1174
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1175
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1176
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1177
+ start_idx2 = torch.where(
1178
+ torch.eq(x_idx, 0),
1179
+ torch.tensor(0, device=x.device),
1180
+ torch.where(
1181
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1182
+ ),
1183
+ )
1184
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1185
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1186
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1187
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1188
+ return cand
1189
+
1190
+
1191
+ def expand_dims(v, dims):
1192
+ """
1193
+ Expand the tensor `v` to the dim `dims`.
1194
+
1195
+ Args:
1196
+ `v`: a PyTorch tensor with shape [N].
1197
+ `dim`: a `int`.
1198
+ Returns:
1199
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1200
+ """
1201
+ return v[(...,) + (None,) * (dims - 1)]
DDSP-SVC/diffusion/infer_gt_mel.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from diffusion.unit2mel import load_model_vocoder
5
+
6
+
7
+ class DiffGtMel:
8
+ def __init__(self, project_path=None, device=None):
9
+ self.project_path = project_path
10
+ if device is not None:
11
+ self.device = device
12
+ else:
13
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ self.model = None
15
+ self.vocoder = None
16
+ self.args = None
17
+
18
+ def flush_model(self, project_path, ddsp_config=None):
19
+ if (self.model is None) or (project_path != self.project_path):
20
+ model, vocoder, args = load_model_vocoder(project_path, device=self.device)
21
+ if self.check_args(ddsp_config, args):
22
+ self.model = model
23
+ self.vocoder = vocoder
24
+ self.args = args
25
+
26
+ def check_args(self, args1, args2):
27
+ if args1.data.block_size != args2.data.block_size:
28
+ raise ValueError("DDSP与DIFF模型的block_size不一致")
29
+ if args1.data.sampling_rate != args2.data.sampling_rate:
30
+ raise ValueError("DDSP与DIFF模型的sampling_rate不一致")
31
+ if args1.data.encoder != args2.data.encoder:
32
+ raise ValueError("DDSP与DIFF模型的encoder不一致")
33
+ return True
34
+
35
+ def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, use_dpm=True,
36
+ spk_mix_dict=None, start_frame=0):
37
+ input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate)
38
+ if use_dpm:
39
+ method = 'dpm-solver'
40
+ else:
41
+ method = 'pndm'
42
+ out_mel = self.model(
43
+ hubert,
44
+ f0,
45
+ volume,
46
+ spk_id=spk_id,
47
+ spk_mix_dict=spk_mix_dict,
48
+ gt_spec=input_mel,
49
+ infer=True,
50
+ infer_speedup=acc,
51
+ method=method,
52
+ k_step=k_step,
53
+ use_tqdm=False)
54
+ if start_frame > 0:
55
+ out_mel = out_mel[:, start_frame:, :]
56
+ f0 = f0[:, start_frame:, :]
57
+ output = self.vocoder.infer(out_mel, f0)
58
+ if start_frame > 0:
59
+ output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0))
60
+ return output
61
+
62
+ def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, use_dpm=True, silence_front=0,
63
+ use_silence=False, spk_mix_dict=None):
64
+ start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size)
65
+ if use_silence:
66
+ audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:]
67
+ f0 = f0[:, start_frame:, :]
68
+ hubert = hubert[:, start_frame:, :]
69
+ volume = volume[:, start_frame:, :]
70
+ _start_frame = 0
71
+ else:
72
+ _start_frame = start_frame
73
+ audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step,
74
+ use_dpm=use_dpm, spk_mix_dict=spk_mix_dict, start_frame=_start_frame)
75
+ if use_silence:
76
+ if start_frame > 0:
77
+ audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0))
78
+ return audio
DDSP-SVC/diffusion/solver.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import librosa
6
+ from logger.saver import Saver
7
+ from logger import utils
8
+
9
+ def test(args, model, vocoder, loader_test, saver):
10
+ print(' [*] testing...')
11
+ model.eval()
12
+
13
+ # losses
14
+ test_loss = 0.
15
+
16
+ # intialization
17
+ num_batches = len(loader_test)
18
+ rtf_all = []
19
+
20
+ # run
21
+ with torch.no_grad():
22
+ for bidx, data in enumerate(loader_test):
23
+ fn = data['name'][0]
24
+ print('--------')
25
+ print('{}/{} - {}'.format(bidx, num_batches, fn))
26
+
27
+ # unpack data
28
+ for k in data.keys():
29
+ if k != 'name':
30
+ data[k] = data[k].to(args.device)
31
+ print('>>', data['name'][0])
32
+
33
+ # forward
34
+ st_time = time.time()
35
+ mel = model(
36
+ data['units'],
37
+ data['f0'],
38
+ data['volume'],
39
+ data['spk_id'],
40
+ gt_spec=None,
41
+ infer=True,
42
+ infer_speedup=args.infer.speedup,
43
+ method=args.infer.method)
44
+ signal = vocoder.infer(mel, data['f0'])
45
+ ed_time = time.time()
46
+
47
+ # RTF
48
+ run_time = ed_time - st_time
49
+ song_time = signal.shape[-1] / args.data.sampling_rate
50
+ rtf = run_time / song_time
51
+ print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
52
+ rtf_all.append(rtf)
53
+
54
+ # loss
55
+ for i in range(args.train.batch_size):
56
+ loss = model(
57
+ data['units'],
58
+ data['f0'],
59
+ data['volume'],
60
+ data['spk_id'],
61
+ gt_spec=data['mel'],
62
+ infer=False)
63
+ test_loss += loss.item()
64
+
65
+ # log mel
66
+ saver.log_spec(data['name'][0], data['mel'], mel)
67
+
68
+ # log audio
69
+ path_audio = os.path.join(args.data.valid_path, 'audio', data['name'][0]) + '.wav'
70
+ audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
71
+ if len(audio.shape) > 1:
72
+ audio = librosa.to_mono(audio)
73
+ audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
74
+ saver.log_audio({fn+'/gt.wav': audio, fn+'/pred.wav': signal})
75
+
76
+ # report
77
+ test_loss /= args.train.batch_size
78
+ test_loss /= num_batches
79
+
80
+ # check
81
+ print(' [test_loss] test_loss:', test_loss)
82
+ print(' Real Time Factor', np.mean(rtf_all))
83
+ return test_loss
84
+
85
+
86
+ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
87
+ # saver
88
+ saver = Saver(args, initial_global_step=initial_global_step)
89
+
90
+ # model size
91
+ params_count = utils.get_network_paras_amount({'model': model})
92
+ saver.log_info('--- model size ---')
93
+ saver.log_info(params_count)
94
+
95
+ # run
96
+ num_batches = len(loader_train)
97
+ model.train()
98
+ saver.log_info('======= start training =======')
99
+ for epoch in range(args.train.epochs):
100
+ for batch_idx, data in enumerate(loader_train):
101
+ saver.global_step_increment()
102
+ optimizer.zero_grad()
103
+
104
+ # unpack data
105
+ for k in data.keys():
106
+ if k != 'name':
107
+ data[k] = data[k].to(args.device)
108
+
109
+ # forward
110
+ loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
111
+ aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
112
+
113
+ # handle nan loss
114
+ if torch.isnan(loss):
115
+ raise ValueError(' [x] nan loss ')
116
+ else:
117
+ # backpropagate
118
+ loss.backward()
119
+ optimizer.step()
120
+ scheduler.step()
121
+
122
+ # log loss
123
+ if saver.global_step % args.train.interval_log == 0:
124
+ current_lr = optimizer.param_groups[0]['lr']
125
+ saver.log_info(
126
+ 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
127
+ epoch,
128
+ batch_idx,
129
+ num_batches,
130
+ args.env.expdir,
131
+ args.train.interval_log/saver.get_interval_time(),
132
+ current_lr,
133
+ loss.item(),
134
+ saver.get_total_time(),
135
+ saver.global_step
136
+ )
137
+ )
138
+
139
+ saver.log_value({
140
+ 'train/loss': loss.item()
141
+ })
142
+
143
+ saver.log_value({
144
+ 'train/lr': current_lr
145
+ })
146
+
147
+ # validation
148
+ if saver.global_step % args.train.interval_val == 0:
149
+ # save latest
150
+ saver.save_model(model, optimizer, postfix=f'{saver.global_step}')
151
+ last_val_step = saver.global_step - args.train.interval_val
152
+ if last_val_step % args.train.interval_force_save != 0:
153
+ saver.delete_model(postfix=f'{last_val_step}')
154
+
155
+ # run testing set
156
+
157
+ test_loss = test(args, model, vocoder, loader_test, saver)
158
+
159
+ saver.log_info(
160
+ ' --- <validation> --- \nloss: {:.3f}. '.format(
161
+ test_loss,
162
+ )
163
+ )
164
+
165
+ saver.log_value({
166
+ 'validation/loss': test_loss
167
+ })
168
+
169
+ model.train()
170
+
171
+
DDSP-SVC/diffusion/unit2mel.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from .diffusion import GaussianDiffusion
7
+ from .wavenet import WaveNet
8
+ from .vocoder import Vocoder
9
+
10
+ class DotDict(dict):
11
+ def __getattr__(*args):
12
+ val = dict.get(*args)
13
+ return DotDict(val) if type(val) is dict else val
14
+
15
+ __setattr__ = dict.__setitem__
16
+ __delattr__ = dict.__delitem__
17
+
18
+
19
+ def load_model_vocoder(
20
+ model_path,
21
+ device='cpu'):
22
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
23
+ with open(config_file, "r") as config:
24
+ args = yaml.safe_load(config)
25
+ args = DotDict(args)
26
+
27
+ # load vocoder
28
+ vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
29
+
30
+ # load model
31
+ model = Unit2Mel(
32
+ args.data.encoder_out_channels,
33
+ args.model.n_spk,
34
+ args.model.use_pitch_aug,
35
+ vocoder.dimension,
36
+ args.model.n_layers,
37
+ args.model.n_chans,
38
+ args.model.n_hidden)
39
+
40
+ print(' [Loading] ' + model_path)
41
+ ckpt = torch.load(model_path, map_location=torch.device(device))
42
+ model.to(device)
43
+ model.load_state_dict(ckpt['model'])
44
+ model.eval()
45
+ return model, vocoder, args
46
+
47
+
48
+ class Unit2Mel(nn.Module):
49
+ def __init__(
50
+ self,
51
+ input_channel,
52
+ n_spk,
53
+ use_pitch_aug=False,
54
+ out_dims=128,
55
+ n_layers=20,
56
+ n_chans=384,
57
+ n_hidden=256):
58
+ super().__init__()
59
+ self.unit_embed = nn.Linear(input_channel, n_hidden)
60
+ self.f0_embed = nn.Linear(1, n_hidden)
61
+ self.volume_embed = nn.Linear(1, n_hidden)
62
+ if use_pitch_aug:
63
+ self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
64
+ else:
65
+ self.aug_shift_embed = None
66
+ self.n_spk = n_spk
67
+ if n_spk is not None and n_spk > 1:
68
+ self.spk_embed = nn.Embedding(n_spk, n_hidden)
69
+
70
+ # diffusion
71
+ self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
72
+
73
+ def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
74
+ gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
75
+
76
+ '''
77
+ input:
78
+ B x n_frames x n_unit
79
+ return:
80
+ dict of B x n_frames x feat
81
+ '''
82
+
83
+ x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
84
+ if self.n_spk is not None and self.n_spk > 1:
85
+ if spk_mix_dict is not None:
86
+ for k, v in spk_mix_dict.items():
87
+ spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
88
+ x = x + v * self.spk_embed(spk_id_torch - 1)
89
+ else:
90
+ x = x + self.spk_embed(spk_id - 1)
91
+ if self.aug_shift_embed is not None and aug_shift is not None:
92
+ x = x + self.aug_shift_embed(aug_shift / 5)
93
+ x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)
94
+
95
+ return x
96
+
DDSP-SVC/diffusion/vocoder.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from nsf_hifigan.nvSTFT import STFT
3
+ from nsf_hifigan.models import load_model
4
+ from torchaudio.transforms import Resample
5
+
6
+
7
+ class Vocoder:
8
+ def __init__(self, vocoder_type, vocoder_ckpt, device = None):
9
+ if device is None:
10
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+ self.device = device
12
+
13
+ if vocoder_type == 'nsf-hifigan':
14
+ self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device)
15
+ elif vocoder_type == 'nsf-hifigan-log10':
16
+ self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device)
17
+ else:
18
+ raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
19
+
20
+ self.resample_kernel = {}
21
+ self.vocoder_sample_rate = self.vocoder.sample_rate()
22
+ self.vocoder_hop_size = self.vocoder.hop_size()
23
+ self.dimension = self.vocoder.dimension()
24
+
25
+ def extract(self, audio, sample_rate, keyshift=0):
26
+
27
+ # resample
28
+ if sample_rate == self.vocoder_sample_rate:
29
+ audio_res = audio
30
+ else:
31
+ key_str = str(sample_rate)
32
+ if key_str not in self.resample_kernel:
33
+ self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
34
+ audio_res = self.resample_kernel[key_str](audio)
35
+
36
+ # extract
37
+ mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
38
+ return mel
39
+
40
+ def infer(self, mel, f0):
41
+ f0 = f0[:,:mel.size(1),0] # B, n_frames
42
+ audio = self.vocoder(mel, f0)
43
+ return audio
44
+
45
+
46
+ class NsfHifiGAN(torch.nn.Module):
47
+ def __init__(self, model_path, device=None):
48
+ super().__init__()
49
+ if device is None:
50
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
51
+ self.device = device
52
+ print('| Load HifiGAN: ', model_path)
53
+ self.model, self.h = load_model(model_path, device=self.device)
54
+ self.stft = STFT(
55
+ self.h.sampling_rate,
56
+ self.h.num_mels,
57
+ self.h.n_fft,
58
+ self.h.win_size,
59
+ self.h.hop_size,
60
+ self.h.fmin,
61
+ self.h.fmax)
62
+
63
+ def sample_rate(self):
64
+ return self.h.sampling_rate
65
+
66
+ def hop_size(self):
67
+ return self.h.hop_size
68
+
69
+ def dimension(self):
70
+ return self.h.num_mels
71
+
72
+ def extract(self, audio, keyshift=0):
73
+ mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins
74
+ return mel
75
+
76
+ def forward(self, mel, f0):
77
+ with torch.no_grad():
78
+ c = mel.transpose(1, 2)
79
+ audio = self.model(c, f0)
80
+ return audio
81
+
82
+ class NsfHifiGANLog10(NsfHifiGAN):
83
+ def forward(self, mel, f0):
84
+ with torch.no_grad():
85
+ c = 0.434294 * mel.transpose(1, 2)
86
+ audio = self.model(c, f0)
87
+ return audio
DDSP-SVC/diffusion/wavenet.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from math import sqrt
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn import Mish
8
+
9
+
10
+ class Conv1d(torch.nn.Conv1d):
11
+ def __init__(self, *args, **kwargs):
12
+ super().__init__(*args, **kwargs)
13
+ nn.init.kaiming_normal_(self.weight)
14
+
15
+
16
+ class SinusoidalPosEmb(nn.Module):
17
+ def __init__(self, dim):
18
+ super().__init__()
19
+ self.dim = dim
20
+
21
+ def forward(self, x):
22
+ device = x.device
23
+ half_dim = self.dim // 2
24
+ emb = math.log(10000) / (half_dim - 1)
25
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
26
+ emb = x[:, None] * emb[None, :]
27
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
28
+ return emb
29
+
30
+
31
+ class ResidualBlock(nn.Module):
32
+ def __init__(self, encoder_hidden, residual_channels, dilation):
33
+ super().__init__()
34
+ self.residual_channels = residual_channels
35
+ self.dilated_conv = nn.Conv1d(
36
+ residual_channels,
37
+ 2 * residual_channels,
38
+ kernel_size=3,
39
+ padding=dilation,
40
+ dilation=dilation
41
+ )
42
+ self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
43
+ self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
44
+ self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)
45
+
46
+ def forward(self, x, conditioner, diffusion_step):
47
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
48
+ conditioner = self.conditioner_projection(conditioner)
49
+ y = x + diffusion_step
50
+
51
+ y = self.dilated_conv(y) + conditioner
52
+
53
+ # Using torch.split instead of torch.chunk to avoid using onnx::Slice
54
+ gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
55
+ y = torch.sigmoid(gate) * torch.tanh(filter)
56
+
57
+ y = self.output_projection(y)
58
+
59
+ # Using torch.split instead of torch.chunk to avoid using onnx::Slice
60
+ residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
61
+ return (x + residual) / math.sqrt(2.0), skip
62
+
63
+
64
+ class WaveNet(nn.Module):
65
+ def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
66
+ super().__init__()
67
+ self.input_projection = Conv1d(in_dims, n_chans, 1)
68
+ self.diffusion_embedding = SinusoidalPosEmb(n_chans)
69
+ self.mlp = nn.Sequential(
70
+ nn.Linear(n_chans, n_chans * 4),
71
+ Mish(),
72
+ nn.Linear(n_chans * 4, n_chans)
73
+ )
74
+ self.residual_layers = nn.ModuleList([
75
+ ResidualBlock(
76
+ encoder_hidden=n_hidden,
77
+ residual_channels=n_chans,
78
+ dilation=1
79
+ )
80
+ for i in range(n_layers)
81
+ ])
82
+ self.skip_projection = Conv1d(n_chans, n_chans, 1)
83
+ self.output_projection = Conv1d(n_chans, in_dims, 1)
84
+ nn.init.zeros_(self.output_projection.weight)
85
+
86
+ def forward(self, spec, diffusion_step, cond):
87
+ """
88
+ :param spec: [B, 1, M, T]
89
+ :param diffusion_step: [B, 1]
90
+ :param cond: [B, M, T]
91
+ :return:
92
+ """
93
+ x = spec.squeeze(1)
94
+ x = self.input_projection(x) # [B, residual_channel, T]
95
+
96
+ x = F.relu(x)
97
+ diffusion_step = self.diffusion_embedding(diffusion_step)
98
+ diffusion_step = self.mlp(diffusion_step)
99
+ skip = []
100
+ for layer in self.residual_layers:
101
+ x, skip_connection = layer(x, cond, diffusion_step)
102
+ skip.append(skip_connection)
103
+
104
+ x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
105
+ x = self.skip_projection(x)
106
+ x = F.relu(x)
107
+ x = self.output_projection(x) # [B, mel_bins, T]
108
+ return x[:, None, :, :]
DDSP-SVC/draw.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tqdm
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ import shutil
6
+ import wave
7
+
8
+ WAV_MIN_LENGTH = 2 # wav文件的最短时长 / The minimum duration of wav files
9
+ SAMPLE_RATE = 1 # 抽取文件数量的百分比 / The percentage of files to be extracted
10
+ SAMPLE_MIN = 2 # 抽取的文件数量下限 / The lower limit of the number of files to be extracted
11
+ SAMPLE_MAX = 10 # 抽取的文件数量上限 / The upper limit of the number of files to be extracted
12
+
13
+
14
+ # 定义一个函数,用于检查wav文件的时长是否大于最短时长
15
+ def check_duration(wav_file):
16
+ # 打开wav文件
17
+ f = wave.open(wav_file, "rb")
18
+ # 获取帧数和帧率
19
+ frames = f.getnframes()
20
+ rate = f.getframerate()
21
+ # 计算时长(秒)
22
+ duration = frames / float(rate)
23
+ # 关闭文件
24
+ f.close()
25
+ # 返回时长是否大于最短时长的布尔值
26
+ return duration > WAV_MIN_LENGTH
27
+
28
+ # 定义一个函数,用于从给定的目录中随机抽取一定比例的wav文件,并剪切到另一个目录中,保留数据结构
29
+ def split_data(src_dir, dst_dir, ratio):
30
+ # 创建目标目录(如果不存在)
31
+ if not os.path.exists(dst_dir):
32
+ os.makedirs(dst_dir)
33
+
34
+ # 获取源目录下所有的子目录和文件名
35
+ subdirs, files, subfiles = [], [], []
36
+ for item in os.listdir(src_dir):
37
+ item_path = os.path.join(src_dir, item)
38
+ if os.path.isdir(item_path):
39
+ subdirs.append(item)
40
+ for subitem in os.listdir(item_path):
41
+ subitem_path = os.path.join(item_path, subitem)
42
+ if os.path.isfile(subitem_path) and subitem.endswith(".wav"):
43
+ subfiles.append(subitem)
44
+ elif os.path.isfile(item_path) and item.endswith(".wav"):
45
+ files.append(item)
46
+
47
+ # 如果源目录下没有任何wav文件,则报错并退出函数
48
+ if len(files) == 0:
49
+ if len(subfiles) == 0:
50
+ print(f"Error: No wav files found in {src_dir}")
51
+ return
52
+
53
+ # 计算需要抽取的wav文件数量
54
+ num_files = int(len(files) * ratio)
55
+ num_files = max(SAMPLE_MIN, min(SAMPLE_MAX, num_files))
56
+
57
+ # 随机打乱文件名列表,并取出前num_files个作为抽取结果
58
+ np.random.shuffle(files)
59
+ selected_files = files[:num_files]
60
+
61
+ # 创建一个进度条对象,用于显示程序的运行进度
62
+ pbar = tqdm.tqdm(total=num_files)
63
+
64
+ # 遍历抽取结果中的每个文件名,检查是否大于2秒
65
+ for file in selected_files:
66
+ src_file = os.path.join(src_dir, file)
67
+ # 检查源文件的时长是否大于2秒,如果不是,则打印源文件的文件名,并跳过该文件
68
+ if not check_duration(src_file):
69
+ print(f"Skipped {src_file} because its duration is less than 2 seconds.")
70
+ continue
71
+ # 拼接源文件和目标文件的完整路径,移动文件,并更新进度条
72
+ dst_file = os.path.join(dst_dir, file)
73
+ shutil.move(src_file, dst_file)
74
+ pbar.update(1)
75
+
76
+ pbar.close()
77
+
78
+ # 遍历源目录下所有的子目录(如果有)
79
+ for subdir in subdirs:
80
+ # 拼接子目录在源目录和目标目录中的完整路径
81
+ src_subdir = os.path.join(src_dir, subdir)
82
+ dst_subdir = os.path.join(dst_dir, subdir)
83
+ # 递归地调用本函数,对子目录中的wav文件进行同样的操作,保留数据结构
84
+ split_data(src_subdir, dst_subdir, ratio)
85
+
86
+ # 定义主函数,用于获取用户输入并调用上述函数
87
+
88
+ def main():
89
+ root_dir = os.path.abspath('.')
90
+ dst_dir = root_dir + "/data/val/audio"
91
+ # 抽取比例,默认为1
92
+ ratio = float(SAMPLE_RATE) / 100
93
+
94
+ # 固定源目录为根目录下/data/train/audio目录
95
+ src_dir = root_dir + "/data/train/audio"
96
+
97
+ # 调用split_data函数,对源目录中的wav文件进行抽取,并剪切到目标目录中,保留数据结构
98
+ split_data(src_dir, dst_dir, ratio)
99
+
100
+ # 如果本模块是主模块,则执行主函数
101
+ if __name__ == "__main__":
102
+ main()
DDSP-SVC/encoder/hubert/model.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional, Tuple
3
+ import random
4
+
5
+ from sklearn.cluster import KMeans
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
11
+
12
+ URLS = {
13
+ "hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
14
+ "hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
15
+ "kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
16
+ }
17
+
18
+
19
+ class Hubert(nn.Module):
20
+ def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
21
+ super().__init__()
22
+ self._mask = mask
23
+ self.feature_extractor = FeatureExtractor()
24
+ self.feature_projection = FeatureProjection()
25
+ self.positional_embedding = PositionalConvEmbedding()
26
+ self.norm = nn.LayerNorm(768)
27
+ self.dropout = nn.Dropout(0.1)
28
+ self.encoder = TransformerEncoder(
29
+ nn.TransformerEncoderLayer(
30
+ 768, 12, 3072, activation="gelu", batch_first=True
31
+ ),
32
+ 12,
33
+ )
34
+ self.proj = nn.Linear(768, 256)
35
+
36
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
37
+ self.label_embedding = nn.Embedding(num_label_embeddings, 256)
38
+
39
+ def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ mask = None
41
+ if self.training and self._mask:
42
+ mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
43
+ x[mask] = self.masked_spec_embed.to(x.dtype)
44
+ return x, mask
45
+
46
+ def encode(
47
+ self, x: torch.Tensor, layer: Optional[int] = None
48
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
49
+ x = self.feature_extractor(x)
50
+ x = self.feature_projection(x.transpose(1, 2))
51
+ x, mask = self.mask(x)
52
+ x = x + self.positional_embedding(x)
53
+ x = self.dropout(self.norm(x))
54
+ x = self.encoder(x, output_layer=layer)
55
+ return x, mask
56
+
57
+ def logits(self, x: torch.Tensor) -> torch.Tensor:
58
+ logits = torch.cosine_similarity(
59
+ x.unsqueeze(2),
60
+ self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
61
+ dim=-1,
62
+ )
63
+ return logits / 0.1
64
+
65
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
66
+ x, mask = self.encode(x)
67
+ x = self.proj(x)
68
+ logits = self.logits(x)
69
+ return logits, mask
70
+
71
+
72
+ class HubertSoft(Hubert):
73
+ def __init__(self):
74
+ super().__init__()
75
+
76
+ @torch.inference_mode()
77
+ def units(self, wav: torch.Tensor) -> torch.Tensor:
78
+ wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
79
+ x, _ = self.encode(wav)
80
+ return self.proj(x)
81
+
82
+
83
+ class HubertDiscrete(Hubert):
84
+ def __init__(self, kmeans):
85
+ super().__init__(504)
86
+ self.kmeans = kmeans
87
+
88
+ @torch.inference_mode()
89
+ def units(self, wav: torch.Tensor) -> torch.LongTensor:
90
+ wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
91
+ x, _ = self.encode(wav, layer=7)
92
+ x = self.kmeans.predict(x.squeeze().cpu().numpy())
93
+ return torch.tensor(x, dtype=torch.long, device=wav.device)
94
+
95
+
96
+ class FeatureExtractor(nn.Module):
97
+ def __init__(self):
98
+ super().__init__()
99
+ self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
100
+ self.norm0 = nn.GroupNorm(512, 512)
101
+ self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
102
+ self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
103
+ self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
104
+ self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
105
+ self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
106
+ self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ x = F.gelu(self.norm0(self.conv0(x)))
110
+ x = F.gelu(self.conv1(x))
111
+ x = F.gelu(self.conv2(x))
112
+ x = F.gelu(self.conv3(x))
113
+ x = F.gelu(self.conv4(x))
114
+ x = F.gelu(self.conv5(x))
115
+ x = F.gelu(self.conv6(x))
116
+ return x
117
+
118
+
119
+ class FeatureProjection(nn.Module):
120
+ def __init__(self):
121
+ super().__init__()
122
+ self.norm = nn.LayerNorm(512)
123
+ self.projection = nn.Linear(512, 768)
124
+ self.dropout = nn.Dropout(0.1)
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ x = self.norm(x)
128
+ x = self.projection(x)
129
+ x = self.dropout(x)
130
+ return x
131
+
132
+
133
+ class PositionalConvEmbedding(nn.Module):
134
+ def __init__(self):
135
+ super().__init__()
136
+ self.conv = nn.Conv1d(
137
+ 768,
138
+ 768,
139
+ kernel_size=128,
140
+ padding=128 // 2,
141
+ groups=16,
142
+ )
143
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
144
+
145
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
146
+ x = self.conv(x.transpose(1, 2))
147
+ x = F.gelu(x[:, :, :-1])
148
+ return x.transpose(1, 2)
149
+
150
+
151
+ class TransformerEncoder(nn.Module):
152
+ def __init__(
153
+ self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
154
+ ) -> None:
155
+ super(TransformerEncoder, self).__init__()
156
+ self.layers = nn.ModuleList(
157
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
158
+ )
159
+ self.num_layers = num_layers
160
+
161
+ def forward(
162
+ self,
163
+ src: torch.Tensor,
164
+ mask: torch.Tensor = None,
165
+ src_key_padding_mask: torch.Tensor = None,
166
+ output_layer: Optional[int] = None,
167
+ ) -> torch.Tensor:
168
+ output = src
169
+ for layer in self.layers[:output_layer]:
170
+ output = layer(
171
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
172
+ )
173
+ return output
174
+
175
+
176
+ def _compute_mask(
177
+ shape: Tuple[int, int],
178
+ mask_prob: float,
179
+ mask_length: int,
180
+ device: torch.device,
181
+ min_masks: int = 0,
182
+ ) -> torch.Tensor:
183
+ batch_size, sequence_length = shape
184
+
185
+ if mask_length < 1:
186
+ raise ValueError("`mask_length` has to be bigger than 0.")
187
+
188
+ if mask_length > sequence_length:
189
+ raise ValueError(
190
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
191
+ )
192
+
193
+ # compute number of masked spans in batch
194
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
195
+ num_masked_spans = max(num_masked_spans, min_masks)
196
+
197
+ # make sure num masked indices <= sequence_length
198
+ if num_masked_spans * mask_length > sequence_length:
199
+ num_masked_spans = sequence_length // mask_length
200
+
201
+ # SpecAugment mask to fill
202
+ mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
203
+
204
+ # uniform distribution to sample from, make sure that offset samples are < sequence_length
205
+ uniform_dist = torch.ones(
206
+ (batch_size, sequence_length - (mask_length - 1)), device=device
207
+ )
208
+
209
+ # get random indices to mask
210
+ mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
211
+
212
+ # expand masked indices to masked spans
213
+ mask_indices = (
214
+ mask_indices.unsqueeze(dim=-1)
215
+ .expand((batch_size, num_masked_spans, mask_length))
216
+ .reshape(batch_size, num_masked_spans * mask_length)
217
+ )
218
+ offsets = (
219
+ torch.arange(mask_length, device=device)[None, None, :]
220
+ .expand((batch_size, num_masked_spans, mask_length))
221
+ .reshape(batch_size, num_masked_spans * mask_length)
222
+ )
223
+ mask_idxs = mask_indices + offsets
224
+
225
+ # scatter indices to mask
226
+ mask = mask.scatter(1, mask_idxs, True)
227
+
228
+ return mask
229
+
230
+
231
+ def hubert_discrete(
232
+ pretrained: bool = True,
233
+ progress: bool = True,
234
+ ) -> HubertDiscrete:
235
+ r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
236
+ Args:
237
+ pretrained (bool): load pretrained weights into the model
238
+ progress (bool): show progress bar when downloading model
239
+ """
240
+ kmeans = kmeans100(pretrained=pretrained, progress=progress)
241
+ hubert = HubertDiscrete(kmeans)
242
+ if pretrained:
243
+ checkpoint = torch.hub.load_state_dict_from_url(
244
+ URLS["hubert-discrete"], progress=progress
245
+ )
246
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
247
+ hubert.load_state_dict(checkpoint)
248
+ hubert.eval()
249
+ return hubert
250
+
251
+
252
+ def hubert_soft(
253
+ pretrained: bool = True,
254
+ progress: bool = True,
255
+ ) -> HubertSoft:
256
+ r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
257
+ Args:
258
+ pretrained (bool): load pretrained weights into the model
259
+ progress (bool): show progress bar when downloading model
260
+ """
261
+ hubert = HubertSoft()
262
+ if pretrained:
263
+ checkpoint = torch.hub.load_state_dict_from_url(
264
+ URLS["hubert-soft"], progress=progress
265
+ )
266
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
267
+ hubert.load_state_dict(checkpoint)
268
+ hubert.eval()
269
+ return hubert
270
+
271
+
272
+ def _kmeans(
273
+ num_clusters: int, pretrained: bool = True, progress: bool = True
274
+ ) -> KMeans:
275
+ kmeans = KMeans(num_clusters)
276
+ if pretrained:
277
+ checkpoint = torch.hub.load_state_dict_from_url(
278
+ URLS[f"kmeans{num_clusters}"], progress=progress
279
+ )
280
+ kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"]
281
+ kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"]
282
+ kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy()
283
+ return kmeans
284
+
285
+
286
+ def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans:
287
+ r"""
288
+ k-means checkpoint for HuBERT-Discrete with 100 clusters.
289
+ Args:
290
+ pretrained (bool): load pretrained weights into the model
291
+ progress (bool): show progress bar when downloading model
292
+ """
293
+ return _kmeans(100, pretrained, progress)
DDSP-SVC/enhancer.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from nsf_hifigan.nvSTFT import STFT
5
+ from nsf_hifigan.models import load_model
6
+ from torchaudio.transforms import Resample
7
+
8
+ class Enhancer:
9
+ def __init__(self, enhancer_type, enhancer_ckpt, device=None):
10
+ if device is None:
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ self.device = device
13
+
14
+ if enhancer_type == 'nsf-hifigan':
15
+ self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device)
16
+ else:
17
+ raise ValueError(f" [x] Unknown enhancer: {enhancer_type}")
18
+
19
+ self.resample_kernel = {}
20
+ self.enhancer_sample_rate = self.enhancer.sample_rate()
21
+ self.enhancer_hop_size = self.enhancer.hop_size()
22
+
23
+ def enhance(self,
24
+ audio, # 1, T
25
+ sample_rate,
26
+ f0, # 1, n_frames, 1
27
+ hop_size,
28
+ adaptive_key = 0,
29
+ silence_front = 0
30
+ ):
31
+ # enhancer start time
32
+ start_frame = int(silence_front * sample_rate / hop_size)
33
+ real_silence_front = start_frame * hop_size / sample_rate
34
+ audio = audio[:, int(np.round(real_silence_front * sample_rate)) : ]
35
+ f0 = f0[: , start_frame :, :]
36
+
37
+ # adaptive parameters
38
+ if adaptive_key == 'auto':
39
+ adaptive_key = 12 * np.log2(float(torch.max(f0) / 760))
40
+ adaptive_key = max(0, np.ceil(adaptive_key))
41
+ print('auto_adaptive_key: ' + str(int(adaptive_key)))
42
+ else:
43
+ adaptive_key = float(adaptive_key)
44
+
45
+ adaptive_factor = 2 ** ( -adaptive_key / 12)
46
+ adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100))
47
+ real_factor = self.enhancer_sample_rate / adaptive_sample_rate
48
+
49
+ # resample the ddsp output
50
+ if sample_rate == adaptive_sample_rate:
51
+ audio_res = audio
52
+ else:
53
+ key_str = str(sample_rate) + str(adaptive_sample_rate)
54
+ if key_str not in self.resample_kernel:
55
+ self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width = 128).to(self.device)
56
+ audio_res = self.resample_kernel[key_str](audio)
57
+
58
+ n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1)
59
+
60
+ # resample f0
61
+ if hop_size == self.enhancer_hop_size and sample_rate == self.enhancer_sample_rate and sample_rate == adaptive_sample_rate:
62
+ f0_res = f0.squeeze(-1) # 1, n_frames
63
+ else:
64
+ f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy()
65
+ f0_np *= real_factor
66
+ time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor
67
+ time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames)
68
+ f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1])
69
+ f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames
70
+
71
+ # enhance
72
+ enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res)
73
+
74
+ # resample the enhanced output
75
+ if adaptive_sample_rate != enhancer_sample_rate:
76
+ key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate)
77
+ if key_str not in self.resample_kernel:
78
+ self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width = 128).to(self.device)
79
+ enhanced_audio = self.resample_kernel[key_str](enhanced_audio)
80
+
81
+ # pad the silence frames
82
+ if start_frame > 0:
83
+ enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0))
84
+
85
+ return enhanced_audio, enhancer_sample_rate
86
+
87
+
88
+ class NsfHifiGAN(torch.nn.Module):
89
+ def __init__(self, model_path, device=None):
90
+ super().__init__()
91
+ if device is None:
92
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
93
+ self.device = device
94
+ print('| Load HifiGAN: ', model_path)
95
+ self.model, self.h = load_model(model_path, device=self.device)
96
+ self.stft = STFT(
97
+ self.h.sampling_rate,
98
+ self.h.num_mels,
99
+ self.h.n_fft,
100
+ self.h.win_size,
101
+ self.h.hop_size,
102
+ self.h.fmin,
103
+ self.h.fmax)
104
+
105
+ def sample_rate(self):
106
+ return self.h.sampling_rate
107
+
108
+ def hop_size(self):
109
+ return self.h.hop_size
110
+
111
+ def forward(self, audio, f0):
112
+ with torch.no_grad():
113
+ mel = self.stft.get_mel(audio)
114
+ enhanced_audio = self.model(mel, f0[:,:mel.size(-1)])
115
+ return enhanced_audio, self.h.sampling_rate
DDSP-SVC/exp/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
DDSP-SVC/flask_api.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+ import torch
4
+ import numpy as np
5
+ import slicer
6
+ import soundfile as sf
7
+ import librosa
8
+ from flask import Flask, request, send_file
9
+ from flask_cors import CORS
10
+
11
+ from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
12
+ from ddsp.core import upsample
13
+ from enhancer import Enhancer
14
+
15
+
16
+ app = Flask(__name__)
17
+
18
+ CORS(app)
19
+
20
+ logging.getLogger("numba").setLevel(logging.WARNING)
21
+
22
+
23
+ @app.route("/voiceChangeModel", methods=["POST"])
24
+ def voice_change_model():
25
+ request_form = request.form
26
+ wave_file = request.files.get("sample", None)
27
+ # get fSafePrefixPadLength
28
+ f_safe_prefix_pad_length = float(request_form.get("fSafePrefixPadLength", 0))
29
+ print("f_safe_prefix_pad_length:"+str(f_safe_prefix_pad_length))
30
+ # 变调信息
31
+ f_pitch_change = float(request_form.get("fPitchChange", 0))
32
+ # 获取spk_id
33
+ int_speak_id = int(request_form.get("sSpeakId", 0))
34
+ if enable_spk_id_cover:
35
+ int_speak_id = spk_id
36
+ # print("说话人:" + str(int_speak_id))
37
+ # DAW所需的采样率
38
+ daw_sample = int(float(request_form.get("sampleRate", 0)))
39
+ # http获得wav文件并转换
40
+ input_wav_read = io.BytesIO(wave_file.read())
41
+ # 模型推理
42
+ _audio, _model_sr = svc_model.infer(input_wav_read, f_pitch_change, int_speak_id, f_safe_prefix_pad_length)
43
+ tar_audio = librosa.resample(_audio, _model_sr, daw_sample)
44
+ # 返回音频
45
+ out_wav_path = io.BytesIO()
46
+ sf.write(out_wav_path, tar_audio, daw_sample, format="wav")
47
+ out_wav_path.seek(0)
48
+ return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
49
+
50
+
51
+ class SvcDDSP:
52
+ def __init__(self, model_path, vocoder_based_enhancer, enhancer_adaptive_key, input_pitch_extractor,
53
+ f0_min, f0_max, threhold, spk_id, spk_mix_dict, enable_spk_id_cover):
54
+ self.model_path = model_path
55
+ self.vocoder_based_enhancer = vocoder_based_enhancer
56
+ self.enhancer_adaptive_key = enhancer_adaptive_key
57
+ self.input_pitch_extractor = input_pitch_extractor
58
+ self.f0_min = f0_min
59
+ self.f0_max = f0_max
60
+ self.threhold = threhold
61
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
62
+ self.spk_id = spk_id
63
+ self.spk_mix_dict = spk_mix_dict
64
+ self.enable_spk_id_cover = enable_spk_id_cover
65
+
66
+ # load ddsp model
67
+ self.model, self.args = load_model(self.model_path, device=self.device)
68
+
69
+ # load units encoder
70
+ if self.args.data.encoder == 'cnhubertsoftfish':
71
+ cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
72
+ else:
73
+ cnhubertsoft_gate = 10
74
+ self.units_encoder = Units_Encoder(
75
+ self.args.data.encoder,
76
+ self.args.data.encoder_ckpt,
77
+ self.args.data.encoder_sample_rate,
78
+ self.args.data.encoder_hop_size,
79
+ cnhubertsoft_gate=cnhubertsoft_gate,
80
+ device=self.device)
81
+
82
+ # load enhancer
83
+ if self.vocoder_based_enhancer:
84
+ self.enhancer = Enhancer(self.args.enhancer.type, self.args.enhancer.ckpt, device=self.device)
85
+
86
+ def infer(self, input_wav, pitch_adjust, speaker_id, safe_prefix_pad_length):
87
+ print("Infer!")
88
+ # load input
89
+ audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
90
+ if len(audio.shape) > 1:
91
+ audio = librosa.to_mono(audio)
92
+ hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
93
+
94
+ # safe front silence
95
+ if safe_prefix_pad_length > 0.03:
96
+ silence_front = safe_prefix_pad_length - 0.03
97
+ else:
98
+ silence_front = 0
99
+
100
+ # extract f0
101
+ pitch_extractor = F0_Extractor(
102
+ self.input_pitch_extractor,
103
+ sample_rate,
104
+ hop_size,
105
+ float(self.f0_min),
106
+ float(self.f0_max))
107
+ f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
108
+ f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
109
+ f0 = f0 * 2 ** (float(pitch_adjust) / 12)
110
+
111
+ # extract volume
112
+ volume_extractor = Volume_Extractor(hop_size)
113
+ volume = volume_extractor.extract(audio)
114
+ mask = (volume > 10 ** (float(self.threhold) / 20)).astype('float')
115
+ mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
116
+ mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
117
+ mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
118
+ mask = upsample(mask, self.args.data.block_size).squeeze(-1)
119
+ volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
120
+
121
+ # extract units
122
+ audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
123
+ units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
124
+
125
+ # spk_id or spk_mix_dict
126
+ if self.enable_spk_id_cover:
127
+ spk_id = self.spk_id
128
+ else:
129
+ spk_id = speaker_id
130
+ spk_id = torch.LongTensor(np.array([[spk_id]])).to(self.device)
131
+
132
+ # forward and return the output
133
+ with torch.no_grad():
134
+ output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id = spk_id, spk_mix_dict = self.spk_mix_dict)
135
+ output *= mask
136
+ if self.vocoder_based_enhancer:
137
+ output, output_sample_rate = self.enhancer.enhance(
138
+ output,
139
+ self.args.data.sampling_rate,
140
+ f0,
141
+ self.args.data.block_size,
142
+ adaptive_key = self.enhancer_adaptive_key,
143
+ silence_front = silence_front)
144
+ else:
145
+ output_sample_rate = self.args.data.sampling_rate
146
+
147
+ output = output.squeeze().cpu().numpy()
148
+ return output, output_sample_rate
149
+
150
+
151
+ if __name__ == "__main__":
152
+ # ddsp-svc下只需传入下列参数。
153
+ # 对接的是串串香火锅大佬https://github.com/zhaohui8969/VST_NetProcess-。建议使用最新版本。
154
+ # flask部分来自diffsvc小狼大佬编写的代码。
155
+ # config和模型得同一目录。
156
+ checkpoint_path = "exp/multi_speaker/model_300000.pt"
157
+ # 是否使用预训练的基于声码器的增强器增强输出,但对硬件要求更高。
158
+ use_vocoder_based_enhancer = True
159
+ # 结合增强器使用,0为正常音域范围(最高G5)内的高音频质量,大于0则可以防止超高音破音
160
+ enhancer_adaptive_key = 0
161
+ # f0提取器,有parselmouth, dio, harvest, crepe
162
+ select_pitch_extractor = 'crepe'
163
+ # f0范围限制(Hz)
164
+ limit_f0_min = 50
165
+ limit_f0_max = 1100
166
+ # 音量响应阈值(dB)
167
+ threhold = -60
168
+ # 默认说话人。以及是否优先使用默认说话人覆盖vst传入的参数。
169
+ spk_id = 1
170
+ enable_spk_id_cover = True
171
+ # 混合说话人字典(捏音色功能)
172
+ # 设置为非 None 字典会覆盖 spk_id
173
+ spk_mix_dict = None # {1:0.5, 2:0.5} 表示1号说话人和2号说话人的音色按照0.5:0.5的比例混合
174
+ svc_model = SvcDDSP(checkpoint_path, use_vocoder_based_enhancer, enhancer_adaptive_key, select_pitch_extractor,
175
+ limit_f0_min, limit_f0_max, threhold, spk_id, spk_mix_dict, enable_spk_id_cover)
176
+
177
+ # 此处与vst插件对应,端口必须接上。
178
+ app.run(port=6844, host="0.0.0.0", debug=False, threaded=False)
DDSP-SVC/gui.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PySimpleGUI as sg
2
+ import sounddevice as sd
3
+ import torch, librosa, threading, pickle
4
+ from enhancer import Enhancer
5
+ import numpy as np
6
+ from torch.nn import functional as F
7
+ from torchaudio.transforms import Resample
8
+ from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
9
+ from ddsp.core import upsample
10
+ import time
11
+ import gui_locale
12
+
13
+
14
+ def phase_vocoder(a, b, fade_out, fade_in):
15
+ fa = torch.fft.rfft(a)
16
+ fb = torch.fft.rfft(b)
17
+ absab = torch.abs(fa) + torch.abs(fb)
18
+ n = a.shape[0]
19
+ if n % 2 == 0:
20
+ absab[1:-1] *= 2
21
+ else:
22
+ absab[1:] *= 2
23
+ phia = torch.angle(fa)
24
+ phib = torch.angle(fb)
25
+ deltaphase = phib - phia
26
+ deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
27
+ w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
28
+ t = torch.arange(n).unsqueeze(-1).to(a) / n
29
+ result = a * (fade_out ** 2) + b * (fade_in ** 2) + torch.sum(absab * torch.cos(w * t + phia),
30
+ -1) * fade_out * fade_in / n
31
+ return result
32
+
33
+
34
+ class SvcDDSP:
35
+ def __init__(self) -> None:
36
+ self.model = None
37
+ self.units_encoder = None
38
+ self.encoder_type = None
39
+ self.encoder_ckpt = None
40
+ self.enhancer = None
41
+ self.enhancer_type = None
42
+ self.enhancer_ckpt = None
43
+
44
+ def update_model(self, model_path):
45
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
46
+
47
+ # load ddsp model
48
+ if self.model is None or self.model_path != model_path:
49
+ self.model, self.args = load_model(model_path, device=self.device)
50
+ self.model_path = model_path
51
+
52
+ # load units encoder
53
+ if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt:
54
+ if self.args.data.encoder == 'cnhubertsoftfish':
55
+ cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
56
+ else:
57
+ cnhubertsoft_gate = 10
58
+ self.units_encoder = Units_Encoder(
59
+ self.args.data.encoder,
60
+ self.args.data.encoder_ckpt,
61
+ self.args.data.encoder_sample_rate,
62
+ self.args.data.encoder_hop_size,
63
+ cnhubertsoft_gate=cnhubertsoft_gate,
64
+ device=self.device)
65
+ self.encoder_type = self.args.data.encoder
66
+ self.encoder_ckpt = self.args.data.encoder_ckpt
67
+
68
+ # load enhancer
69
+ if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt:
70
+ self.enhancer = Enhancer(self.args.enhancer.type, self.args.enhancer.ckpt, device=self.device)
71
+ self.enhancer_type = self.args.enhancer.type
72
+ self.enhancer_ckpt = self.args.enhancer.ckpt
73
+
74
+ def infer(self,
75
+ audio,
76
+ sample_rate,
77
+ spk_id=1,
78
+ threhold=-45,
79
+ pitch_adjust=0,
80
+ use_spk_mix=False,
81
+ spk_mix_dict=None,
82
+ use_enhancer=True,
83
+ enhancer_adaptive_key='auto',
84
+ pitch_extractor_type='crepe',
85
+ f0_min=50,
86
+ f0_max=1100,
87
+ safe_prefix_pad_length=0,
88
+ ):
89
+ print("Infering...")
90
+ # load input
91
+ # audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
92
+ hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
93
+ # safe front silence
94
+ if safe_prefix_pad_length > 0.03:
95
+ silence_front = safe_prefix_pad_length - 0.03
96
+ else:
97
+ silence_front = 0
98
+
99
+ # extract f0
100
+ pitch_extractor = F0_Extractor(
101
+ pitch_extractor_type,
102
+ sample_rate,
103
+ hop_size,
104
+ float(f0_min),
105
+ float(f0_max))
106
+ f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
107
+ f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
108
+ f0 = f0 * 2 ** (float(pitch_adjust) / 12)
109
+
110
+ # extract volume
111
+ volume_extractor = Volume_Extractor(hop_size)
112
+ volume = volume_extractor.extract(audio)
113
+ mask = (volume > 10 ** (float(threhold) / 20)).astype('float')
114
+ mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
115
+ mask = np.array([np.max(mask[n: n + 9]) for n in range(len(mask) - 8)])
116
+ mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
117
+ mask = upsample(mask, self.args.data.block_size).squeeze(-1)
118
+ volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
119
+
120
+ # extract units
121
+ audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
122
+ units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
123
+
124
+ # spk_id or spk_mix_dict
125
+ spk_id = torch.LongTensor(np.array([[spk_id]])).to(self.device)
126
+ dictionary = None
127
+ if use_spk_mix:
128
+ dictionary = spk_mix_dict
129
+
130
+ # forward and return the output
131
+ with torch.no_grad():
132
+ output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary)
133
+ output *= mask
134
+ if use_enhancer:
135
+ output, output_sample_rate = self.enhancer.enhance(
136
+ output,
137
+ self.args.data.sampling_rate,
138
+ f0,
139
+ self.args.data.block_size,
140
+ adaptive_key=enhancer_adaptive_key,
141
+ silence_front=silence_front)
142
+ else:
143
+ output_sample_rate = self.args.data.sampling_rate
144
+
145
+ output = output.squeeze()
146
+ return output, output_sample_rate
147
+
148
+
149
+ class Config:
150
+ def __init__(self) -> None:
151
+ self.samplerate = 44100 # Hz
152
+ self.block_time = 1.5 # s
153
+ self.f_pitch_change: float = 0.0 # float(request_form.get("fPitchChange", 0))
154
+ self.spk_id = 1 # 默认说话人。
155
+ self.spk_mix_dict = None # {1:0.5, 2:0.5} 表示1号说话人和2号说话人的音色按照0.5:0.5的比例混合
156
+ self.use_vocoder_based_enhancer = True
157
+ self.use_phase_vocoder = True
158
+ self.checkpoint_path = ''
159
+ self.threhold = -35
160
+ self.buffer_num = 2
161
+ self.crossfade_time = 0.03
162
+ self.select_pitch_extractor = 'harvest' # F0预测器["parselmouth", "dio", "harvest", "crepe"]
163
+ self.use_spk_mix = False
164
+ self.sounddevices = ['', '']
165
+
166
+ def save(self, path):
167
+ with open(path + '\\config.pkl', 'wb') as f:
168
+ pickle.dump(vars(self), f)
169
+
170
+ def load(self, path) -> bool:
171
+ try:
172
+ with open(path + '\\config.pkl', 'rb') as f:
173
+ self.update(pickle.load(f))
174
+ return True
175
+ except:
176
+ print('config.pkl does not exist')
177
+ return False
178
+
179
+ def update(self, data_dict):
180
+ for key, value in data_dict.items():
181
+ setattr(self, key, value)
182
+
183
+
184
+ class GUI:
185
+ def __init__(self) -> None:
186
+ self.config = Config()
187
+ self.flag_vc: bool = False # 变声线程flag
188
+ self.block_frame = 0
189
+ self.crossfade_frame = 0
190
+ self.sola_search_frame = 0
191
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
192
+ self.svc_model: SvcDDSP = SvcDDSP()
193
+ self.fade_in_window: np.ndarray = None # crossfade计算用numpy数组
194
+ self.fade_out_window: np.ndarray = None # crossfade计算用numpy数组
195
+ self.input_wav: np.ndarray = None # 输入音频规范化后的保存地址
196
+ self.output_wav: np.ndarray = None # 输出音频规范化后的保存地址
197
+ self.sola_buffer: torch.Tensor = None # 保存上一个output的crossfade
198
+ self.f0_mode_list = ["parselmouth", "dio", "harvest", "crepe"] # F0预测器
199
+ self.f_safe_prefix_pad_length: float = 0.0
200
+ self.resample_kernel = {}
201
+ self.launcher() # start
202
+
203
+ def launcher(self):
204
+ '''窗口加载'''
205
+ input_devices, output_devices, _, _ = self.get_devices()
206
+ sg.theme('DarkAmber') # 设置主题
207
+ # 界面布局
208
+ layout = [
209
+ [sg.Frame(layout=[
210
+ [sg.Input(key='sg_model', default_text='exp\\multi_speaker\\model_300000.pt'),
211
+ sg.FileBrowse(i18n('选择模型文件'), key='choose_model')]
212
+ ], title=i18n('模型:.pt格式(自动识别同目录下config.yaml)')),
213
+ sg.Frame(layout=[
214
+ [sg.Text(i18n('选择配置文件所在目录')), sg.Input(key='config_file_dir', default_text='exp'),
215
+ sg.FolderBrowse(i18n('打开文件夹'), key='choose_config')],
216
+ [sg.Button(i18n('读取配置文件'), key='load_config'), sg.Button(i18n('保存配置文件'), key='save_config')]
217
+ ], title=i18n('快速配置文件'))
218
+ ],
219
+ [sg.Frame(layout=[
220
+ [sg.Text(i18n("输入设备")),
221
+ sg.Combo(input_devices, key='sg_input_device', default_value=input_devices[sd.default.device[0]],
222
+ enable_events=True)],
223
+ [sg.Text(i18n("输出设备")),
224
+ sg.Combo(output_devices, key='sg_output_device', default_value=output_devices[sd.default.device[1]],
225
+ enable_events=True)]
226
+ ], title=i18n('音频设备'))
227
+ ],
228
+ [sg.Frame(layout=[
229
+ [sg.Text(i18n("说话人id")), sg.Input(key='spk_id', default_text='1')],
230
+ [sg.Text(i18n("响应阈值")),
231
+ sg.Slider(range=(-60, 0), orientation='h', key='threhold', resolution=1, default_value=-45,
232
+ enable_events=True)],
233
+ [sg.Text(i18n("变调")),
234
+ sg.Slider(range=(-24, 24), orientation='h', key='pitch', resolution=1, default_value=0,
235
+ enable_events=True)],
236
+ [sg.Text(i18n("采样率")), sg.Input(key='samplerate', default_text='44100')],
237
+ [sg.Checkbox(text=i18n('启用捏音色功能'), default=False, key='spk_mix', enable_events=True),
238
+ sg.Button(i18n("设置混合音色"), key='set_spk_mix')]
239
+ ], title=i18n('普通设置')),
240
+ sg.Frame(layout=[
241
+ [sg.Text(i18n("音频切分大小")),
242
+ sg.Slider(range=(0.05, 3.0), orientation='h', key='block', resolution=0.01, default_value=0.3,
243
+ enable_events=True)],
244
+ [sg.Text(i18n("交叉淡化时长")),
245
+ sg.Slider(range=(0.01, 0.15), orientation='h', key='crossfade', resolution=0.01,
246
+ default_value=0.04, enable_events=True)],
247
+ [sg.Text(i18n("使用历史区块数量")),
248
+ sg.Slider(range=(1, 20), orientation='h', key='buffernum', resolution=1, default_value=4,
249
+ enable_events=True)],
250
+ [sg.Text(i18n("f0预测模式")),
251
+ sg.Combo(values=self.f0_mode_list, key='f0_mode', default_value=self.f0_mode_list[2],
252
+ enable_events=True)],
253
+ [sg.Checkbox(text=i18n('启用增强器'), default=True, key='use_enhancer', enable_events=True),
254
+ sg.Checkbox(text=i18n('启用相位声码器'), default=False, key='use_phase_vocoder', enable_events=True)]
255
+ ], title=i18n('性能设置')),
256
+ ],
257
+ [sg.Button(i18n("开始音频转换"), key="start_vc"), sg.Button(i18n("停止音频转换"), key="stop_vc"),
258
+ sg.Text(i18n('推理所用时间(ms):')), sg.Text('0', key='infer_time')]
259
+ ]
260
+
261
+ # 创造窗口
262
+ self.window = sg.Window('DDSP - GUI', layout, finalize=True)
263
+ self.window['spk_id'].bind('<Return>', '')
264
+ self.window['samplerate'].bind('<Return>', '')
265
+ self.event_handler()
266
+
267
+ def event_handler(self):
268
+ '''事件处理'''
269
+ while True: # 事件处理循环
270
+ event, values = self.window.read()
271
+ print('event: ' + event)
272
+ if event == sg.WINDOW_CLOSED: # 如果用户关闭窗口
273
+ self.flag_vc = False
274
+ exit()
275
+ elif event == 'start_vc' and self.flag_vc == False:
276
+ # set values 和界面布局layout顺序一一对应
277
+ self.set_values(values)
278
+ print('crossfade_time:' + str(self.config.crossfade_time))
279
+ print("buffer_num:" + str(self.config.buffer_num))
280
+ print("samplerate:" + str(self.config.samplerate))
281
+ print('block_time:' + str(self.config.block_time))
282
+ print("prefix_pad_length:" + str(self.f_safe_prefix_pad_length))
283
+ print("mix_mode:" + str(self.config.spk_mix_dict))
284
+ print("enhancer:" + str(self.config.use_vocoder_based_enhancer))
285
+ print('using_cuda:' + str(torch.cuda.is_available()))
286
+ self.start_vc()
287
+ elif event == 'spk_id':
288
+ self.config.spk_id = int(values['spk_id'])
289
+ elif event == 'threhold':
290
+ self.config.threhold = values['threhold']
291
+ elif event == 'pitch':
292
+ self.config.f_pitch_change = values['pitch']
293
+ elif event == 'spk_mix':
294
+ self.config.use_spk_mix = values['spk_mix']
295
+ elif event == 'set_spk_mix':
296
+ spk_mix = sg.popup_get_text(message='示例:1:0.3,2:0.5,3:0.2', title="设置混合音色,支持多人")
297
+ if spk_mix != None:
298
+ self.config.spk_mix_dict = eval("{" + spk_mix.replace(',', ',').replace(':', ':') + "}")
299
+ elif event == 'f0_mode':
300
+ self.config.select_pitch_extractor = values['f0_mode']
301
+ elif event == 'use_enhancer':
302
+ self.config.use_vocoder_based_enhancer = values['use_enhancer']
303
+ elif event == 'use_phase_vocoder':
304
+ self.config.use_phase_vocoder = values['use_phase_vocoder']
305
+ elif event == 'load_config' and self.flag_vc == False:
306
+ if self.config.load(values['config_file_dir']):
307
+ self.update_values()
308
+ elif event == 'save_config' and self.flag_vc == False:
309
+ self.set_values(values)
310
+ self.config.save(values['config_file_dir'])
311
+ elif event != 'start_vc' and self.flag_vc == True:
312
+ self.flag_vc = False
313
+
314
+ def set_values(self, values):
315
+ self.set_devices(values["sg_input_device"], values['sg_output_device'])
316
+ self.config.sounddevices = [values["sg_input_device"], values['sg_output_device']]
317
+ self.config.checkpoint_path = values['sg_model']
318
+ self.config.spk_id = int(values['spk_id'])
319
+ self.config.threhold = values['threhold']
320
+ self.config.f_pitch_change = values['pitch']
321
+ self.config.samplerate = int(values['samplerate'])
322
+ self.config.block_time = float(values['block'])
323
+ self.config.crossfade_time = float(values['crossfade'])
324
+ self.config.buffer_num = int(values['buffernum'])
325
+ self.config.select_pitch_extractor = values['f0_mode']
326
+ self.config.use_vocoder_based_enhancer = values['use_enhancer']
327
+ self.config.use_phase_vocoder = values['use_phase_vocoder']
328
+ self.config.use_spk_mix = values['spk_mix']
329
+ self.block_frame = int(self.config.block_time * self.config.samplerate)
330
+ self.crossfade_frame = int(self.config.crossfade_time * self.config.samplerate)
331
+ self.sola_search_frame = int(0.01 * self.config.samplerate)
332
+ self.last_delay_frame = int(0.02 * self.config.samplerate)
333
+ self.input_frames = max(
334
+ self.block_frame + self.crossfade_frame + self.sola_search_frame + 2 * self.last_delay_frame,
335
+ (1 + self.config.buffer_num) * self.block_frame)
336
+ self.f_safe_prefix_pad_length = self.config.block_time * self.config.buffer_num - self.config.crossfade_time - 0.01 - 0.02
337
+
338
+ def update_values(self):
339
+ self.window['sg_model'].update(self.config.checkpoint_path)
340
+ self.window['sg_input_device'].update(self.config.sounddevices[0])
341
+ self.window['sg_output_device'].update(self.config.sounddevices[1])
342
+ self.window['spk_id'].update(self.config.spk_id)
343
+ self.window['threhold'].update(self.config.threhold)
344
+ self.window['pitch'].update(self.config.f_pitch_change)
345
+ self.window['samplerate'].update(self.config.samplerate)
346
+ self.window['spk_mix'].update(self.config.use_spk_mix)
347
+ self.window['block'].update(self.config.block_time)
348
+ self.window['crossfade'].update(self.config.crossfade_time)
349
+ self.window['buffernum'].update(self.config.buffer_num)
350
+ self.window['f0_mode'].update(self.config.select_pitch_extractor)
351
+ self.window['use_enhancer'].update(self.config.use_vocoder_based_enhancer)
352
+
353
+ def start_vc(self):
354
+ '''开始音频转换'''
355
+ torch.cuda.empty_cache()
356
+ self.flag_vc = True
357
+ self.input_wav = np.zeros(self.input_frames, dtype='float32')
358
+ self.sola_buffer = torch.zeros(self.crossfade_frame, device=self.device)
359
+ self.fade_in_window = torch.sin(
360
+ np.pi * torch.arange(0, 1, 1 / self.crossfade_frame, device=self.device) / 2) ** 2
361
+ self.fade_out_window = 1 - self.fade_in_window
362
+ self.svc_model.update_model(self.config.checkpoint_path)
363
+ thread_vc = threading.Thread(target=self.soundinput)
364
+ thread_vc.start()
365
+
366
+ def soundinput(self):
367
+ '''
368
+ 接受音频输入
369
+ '''
370
+ with sd.Stream(callback=self.audio_callback, blocksize=self.block_frame, samplerate=self.config.samplerate,
371
+ dtype='float32'):
372
+ while self.flag_vc:
373
+ time.sleep(self.config.block_time)
374
+ print('Audio block passed.')
375
+ print('ENDing VC')
376
+
377
+ def audio_callback(self, indata: np.ndarray, outdata: np.ndarray, frames, times, status):
378
+ '''
379
+ 音频处理
380
+ '''
381
+ start_time = time.perf_counter()
382
+ print("\nStarting callback")
383
+ self.input_wav[:] = np.roll(self.input_wav, -self.block_frame)
384
+ self.input_wav[-self.block_frame:] = librosa.to_mono(indata.T)
385
+
386
+ # infer
387
+ _audio, _model_sr = self.svc_model.infer(
388
+ self.input_wav,
389
+ self.config.samplerate,
390
+ spk_id=self.config.spk_id,
391
+ threhold=self.config.threhold,
392
+ pitch_adjust=self.config.f_pitch_change,
393
+ use_spk_mix=self.config.use_spk_mix,
394
+ spk_mix_dict=self.config.spk_mix_dict,
395
+ use_enhancer=self.config.use_vocoder_based_enhancer,
396
+ pitch_extractor_type=self.config.select_pitch_extractor,
397
+ safe_prefix_pad_length=self.f_safe_prefix_pad_length,
398
+ )
399
+
400
+ # debug sola
401
+ '''
402
+ _audio, _model_sr = self.input_wav, self.config.samplerate
403
+ rs = int(np.random.uniform(-200,200))
404
+ print('debug_random_shift: ' + str(rs))
405
+ _audio = np.roll(_audio, rs)
406
+ _audio = torch.from_numpy(_audio).to(self.device)
407
+ '''
408
+
409
+ if _model_sr != self.config.samplerate:
410
+ key_str = str(_model_sr) + '_' + str(self.config.samplerate)
411
+ if key_str not in self.resample_kernel:
412
+ self.resample_kernel[key_str] = Resample(_model_sr, self.config.samplerate,
413
+ lowpass_filter_width=128).to(self.device)
414
+ _audio = self.resample_kernel[key_str](_audio)
415
+ temp_wav = _audio[
416
+ - self.block_frame - self.crossfade_frame - self.sola_search_frame - self.last_delay_frame: - self.last_delay_frame]
417
+
418
+ # sola shift
419
+ conv_input = temp_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
420
+ cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
421
+ cor_den = torch.sqrt(
422
+ F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=self.device)) + 1e-8)
423
+ sola_shift = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
424
+ temp_wav = temp_wav[sola_shift: sola_shift + self.block_frame + self.crossfade_frame]
425
+ print('sola_shift: ' + str(int(sola_shift)))
426
+
427
+ # phase vocoder
428
+ if self.config.use_phase_vocoder:
429
+ temp_wav[: self.crossfade_frame] = phase_vocoder(
430
+ self.sola_buffer,
431
+ temp_wav[: self.crossfade_frame],
432
+ self.fade_out_window,
433
+ self.fade_in_window)
434
+ else:
435
+ temp_wav[: self.crossfade_frame] *= self.fade_in_window
436
+ temp_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
437
+
438
+ self.sola_buffer = temp_wav[- self.crossfade_frame:]
439
+
440
+ outdata[:] = temp_wav[: - self.crossfade_frame, None].repeat(1, 2).cpu().numpy()
441
+ end_time = time.perf_counter()
442
+ print('infer_time: ' + str(end_time - start_time))
443
+ self.window['infer_time'].update(int((end_time - start_time) * 1000))
444
+
445
+ def get_devices(self, update: bool = True):
446
+ '''获取设备列表'''
447
+ if update:
448
+ sd._terminate()
449
+ sd._initialize()
450
+ devices = sd.query_devices()
451
+ hostapis = sd.query_hostapis()
452
+ for hostapi in hostapis:
453
+ for device_idx in hostapi["devices"]:
454
+ devices[device_idx]["hostapi_name"] = hostapi["name"]
455
+ input_devices = [
456
+ f"{d['name']} ({d['hostapi_name']})"
457
+ for d in devices
458
+ if d["max_input_channels"] > 0
459
+ ]
460
+ output_devices = [
461
+ f"{d['name']} ({d['hostapi_name']})"
462
+ for d in devices
463
+ if d["max_output_channels"] > 0
464
+ ]
465
+ input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
466
+ output_devices_indices = [
467
+ d["index"] for d in devices if d["max_output_channels"] > 0
468
+ ]
469
+ return input_devices, output_devices, input_devices_indices, output_devices_indices
470
+
471
+ def set_devices(self, input_device, output_device):
472
+ '''设置输出设备'''
473
+ input_devices, output_devices, input_device_indices, output_device_indices = self.get_devices()
474
+ sd.default.device[0] = input_device_indices[input_devices.index(input_device)]
475
+ sd.default.device[1] = output_device_indices[output_devices.index(output_device)]
476
+ print("input device:" + str(sd.default.device[0]) + ":" + str(input_device))
477
+ print("output device:" + str(sd.default.device[1]) + ":" + str(output_device))
478
+
479
+
480
+
481
+ if __name__ == "__main__":
482
+ i18n = gui_locale.I18nAuto()
483
+ gui = GUI()
DDSP-SVC/gui_diff.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PySimpleGUI as sg
2
+ import sounddevice as sd
3
+ import torch, librosa, threading, pickle
4
+ from enhancer import Enhancer
5
+ import numpy as np
6
+ from torch.nn import functional as F
7
+ from torchaudio.transforms import Resample
8
+ import torchaudio
9
+ from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
10
+ from ddsp.core import upsample
11
+ import time
12
+ from gui_diff_locale import I18nAuto
13
+ from diffusion.infer_gt_mel import DiffGtMel
14
+
15
+
16
+ def phase_vocoder(a, b, fade_out, fade_in):
17
+ fa = torch.fft.rfft(a)
18
+ fb = torch.fft.rfft(b)
19
+ absab = torch.abs(fa) + torch.abs(fb)
20
+ n = a.shape[0]
21
+ if n % 2 == 0:
22
+ absab[1:-1] *= 2
23
+ else:
24
+ absab[1:] *= 2
25
+ phia = torch.angle(fa)
26
+ phib = torch.angle(fb)
27
+ deltaphase = phib - phia
28
+ deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
29
+ w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
30
+ t = torch.arange(n).unsqueeze(-1).to(a) / n
31
+ result = a * (fade_out ** 2) + b * (fade_in ** 2) + torch.sum(absab * torch.cos(w * t + phia),
32
+ -1) * fade_out * fade_in / n
33
+ return result
34
+
35
+
36
+ class SvcDDSP:
37
+ def __init__(self) -> None:
38
+ self.model = None
39
+ self.units_encoder = None
40
+ self.encoder_type = None
41
+ self.encoder_ckpt = None
42
+ self.enhancer = None
43
+ self.enhancer_type = None
44
+ self.enhancer_ckpt = None
45
+
46
+ def update_model(self, model_path):
47
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
48
+
49
+ # load ddsp model
50
+ if self.model is None or self.model_path != model_path:
51
+ self.model, self.args = load_model(model_path, device=self.device)
52
+ self.model_path = model_path
53
+
54
+ # load units encoder
55
+ if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt:
56
+ if self.args.data.encoder == 'cnhubertsoftfish':
57
+ cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
58
+ else:
59
+ cnhubertsoft_gate = 10
60
+ self.units_encoder = Units_Encoder(
61
+ self.args.data.encoder,
62
+ self.args.data.encoder_ckpt,
63
+ self.args.data.encoder_sample_rate,
64
+ self.args.data.encoder_hop_size,
65
+ cnhubertsoft_gate=cnhubertsoft_gate,
66
+ device=self.device)
67
+ self.encoder_type = self.args.data.encoder
68
+ self.encoder_ckpt = self.args.data.encoder_ckpt
69
+
70
+ # load enhancer
71
+ if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt:
72
+ self.enhancer = Enhancer(self.args.enhancer.type, self.args.enhancer.ckpt, device=self.device)
73
+ self.enhancer_type = self.args.enhancer.type
74
+ self.enhancer_ckpt = self.args.enhancer.ckpt
75
+
76
+ def infer(self,
77
+ audio,
78
+ sample_rate,
79
+ spk_id=1,
80
+ threhold=-45,
81
+ pitch_adjust=0,
82
+ use_spk_mix=False,
83
+ spk_mix_dict=None,
84
+ use_enhancer=True,
85
+ enhancer_adaptive_key='auto',
86
+ pitch_extractor_type='crepe',
87
+ f0_min=50,
88
+ f0_max=1100,
89
+ safe_prefix_pad_length=0,
90
+ diff_model=None,
91
+ diff_acc=None,
92
+ diff_spk_id=None,
93
+ diff_use=False,
94
+ diff_use_dpm=False,
95
+ k_step=None,
96
+ diff_silence=False,
97
+ audio_alignment=False
98
+ ):
99
+ print("Infering...")
100
+ # load input
101
+ # audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
102
+ hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
103
+ if audio_alignment:
104
+ audio_length = len(audio)
105
+ # safe front silence
106
+ if safe_prefix_pad_length > 0.03:
107
+ silence_front = safe_prefix_pad_length - 0.03
108
+ else:
109
+ silence_front = 0
110
+ audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
111
+
112
+ # extract f0
113
+ pitch_extractor = F0_Extractor(
114
+ pitch_extractor_type,
115
+ sample_rate,
116
+ hop_size,
117
+ float(f0_min),
118
+ float(f0_max))
119
+ f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
120
+ f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
121
+ f0 = f0 * 2 ** (float(pitch_adjust) / 12)
122
+
123
+ # extract volume
124
+ volume_extractor = Volume_Extractor(hop_size)
125
+ volume = volume_extractor.extract(audio)
126
+ mask = (volume > 10 ** (float(threhold) / 20)).astype('float')
127
+ mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
128
+ mask = np.array([np.max(mask[n: n + 9]) for n in range(len(mask) - 8)])
129
+ mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
130
+ mask = upsample(mask, self.args.data.block_size).squeeze(-1)
131
+ volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
132
+
133
+ # extract units
134
+ units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
135
+
136
+ # spk_id or spk_mix_dict
137
+ spk_id = torch.LongTensor(np.array([[spk_id]])).to(self.device)
138
+ diff_spk_id = torch.LongTensor(np.array([[diff_spk_id]])).to(self.device)
139
+ dictionary = None
140
+ if use_spk_mix:
141
+ dictionary = spk_mix_dict
142
+
143
+ # forward and return the output
144
+ with torch.no_grad():
145
+ output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary)
146
+ if diff_use and diff_model is not None:
147
+ output = diff_model.infer(output, f0, units, volume, acc=diff_acc, spk_id=diff_spk_id,
148
+ k_step=k_step, use_dpm=diff_use_dpm, silence_front=silence_front, use_silence=diff_silence,
149
+ spk_mix_dict=dictionary)
150
+ output *= mask
151
+ if use_enhancer and not diff_use:
152
+ output, output_sample_rate = self.enhancer.enhance(
153
+ output,
154
+ self.args.data.sampling_rate,
155
+ f0,
156
+ self.args.data.block_size,
157
+ adaptive_key=enhancer_adaptive_key,
158
+ silence_front=silence_front)
159
+ else:
160
+ output_sample_rate = self.args.data.sampling_rate
161
+
162
+ output = output.squeeze()
163
+ if audio_alignment:
164
+ output[:audio_length]
165
+ return output, output_sample_rate
166
+
167
+
168
+ class Config:
169
+ def __init__(self) -> None:
170
+ self.samplerate = 44100 # Hz
171
+ self.block_time = 1.5 # s
172
+ self.f_pitch_change: float = 0.0 # float(request_form.get("fPitchChange", 0))
173
+ self.spk_id = 1 # 默认说话人。
174
+ self.spk_mix_dict = None # {1:0.5, 2:0.5} 表示1号说话人和2号说话人的音色按照0.5:0.5的比例混合
175
+ self.use_vocoder_based_enhancer = True
176
+ self.use_phase_vocoder = True
177
+ self.checkpoint_path = ''
178
+ self.threhold = -35
179
+ self.buffer_num = 2
180
+ self.crossfade_time = 0.03
181
+ self.select_pitch_extractor = 'harvest' # F0预测器["parselmouth", "dio", "harvest", "crepe"]
182
+ self.use_spk_mix = False
183
+ self.sounddevices = ['', '']
184
+ self.diff_use = False
185
+ self.diff_project = ''
186
+ self.diff_acc = 10
187
+ self.diff_spk_id = 0
188
+ self.k_step = 100
189
+ self.diff_use_dpm = False
190
+ self.diff_silence = False
191
+
192
+ def save(self, path):
193
+ with open(path + '\\config.pkl', 'wb') as f:
194
+ pickle.dump(vars(self), f)
195
+
196
+ def load(self, path) -> bool:
197
+ try:
198
+ with open(path + '\\config.pkl', 'rb') as f:
199
+ self.update(pickle.load(f))
200
+ return True
201
+ except:
202
+ print('config.pkl does not exist')
203
+ return False
204
+
205
+ def update(self, data_dict):
206
+ for key, value in data_dict.items():
207
+ setattr(self, key, value)
208
+
209
+ class GUI:
210
+ def __init__(self) -> None:
211
+ self.config = Config()
212
+ self.flag_vc: bool = False # 变声线程flag
213
+ self.block_frame = 0
214
+ self.crossfade_frame = 0
215
+ self.sola_search_frame = 0
216
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
217
+ self.svc_model: SvcDDSP = SvcDDSP()
218
+ self.diff_model: DiffGtMel = DiffGtMel()
219
+ self.fade_in_window: np.ndarray = None # crossfade计算用numpy数组
220
+ self.fade_out_window: np.ndarray = None # crossfade计算用numpy数组
221
+ self.input_wav: np.ndarray = None # 输入音频规范化后的保存地址
222
+ self.output_wav: np.ndarray = None # 输出音频规范化后的保存地址
223
+ self.sola_buffer: torch.Tensor = None # 保存上一个output的crossfade
224
+ self.f0_mode_list = ["parselmouth", "dio", "harvest", "crepe"] # F0预测器
225
+ self.f_safe_prefix_pad_length: float = 0.0
226
+ self.resample_kernel = {}
227
+ self.launcher() # start
228
+
229
+ def launcher(self):
230
+ '''窗口加载'''
231
+ input_devices, output_devices, _, _ = self.get_devices()
232
+ sg.theme('DarkBlue12') # 设置主题
233
+ # 界面布局
234
+ layout = [
235
+ [sg.Frame(layout=[
236
+ [sg.Input(key='sg_model', default_text='exp\\combsub-test\\model_300000.pt'),
237
+ sg.FileBrowse(i18n('选择模型文件'), key='choose_model')]
238
+ ], title=i18n('模型:.pt格式(自动识别同目录下config.yaml)')),
239
+ sg.Frame(layout=[
240
+ [sg.Text(i18n('选择配置文件所在目录')), sg.Input(key='config_file_dir', default_text='exp'),
241
+ sg.FolderBrowse(i18n('打开文件夹'), key='choose_config')],
242
+ [sg.Button(i18n('读取配置文件'), key='load_config'),
243
+ sg.Button(i18n('保存配置文件'), key='save_config')]
244
+ ], title=i18n('快速配置文件'))
245
+ ],
246
+ [sg.Frame(layout=[
247
+ [sg.Text(i18n("输入设备")),
248
+ sg.Combo(input_devices, key='sg_input_device', default_value=input_devices[sd.default.device[0]],
249
+ enable_events=True)],
250
+ [sg.Text(i18n("输出设备")),
251
+ sg.Combo(output_devices, key='sg_output_device', default_value=output_devices[sd.default.device[1]],
252
+ enable_events=True)]
253
+ ], title=i18n('音频设备'))
254
+ ],
255
+ [sg.Frame(layout=[
256
+ [sg.Text(i18n("说话人id")), sg.Input(key='spk_id', default_text='1', size=8)],
257
+ [sg.Text(i18n("响应阈值")),
258
+ sg.Slider(range=(-60, 0), orientation='h', key='threhold', resolution=1, default_value=-45,
259
+ enable_events=True)],
260
+ [sg.Text(i18n("变调")),
261
+ sg.Slider(range=(-24, 24), orientation='h', key='pitch', resolution=1, default_value=0,
262
+ enable_events=True)],
263
+ [sg.Text(i18n("采样率")), sg.Input(key='samplerate', default_text='44100', size=8)],
264
+ [sg.Checkbox(text=i18n('启用捏音色功能'), default=False, key='spk_mix', enable_events=True),
265
+ sg.Button(i18n("设置混合音色"), key='set_spk_mix')]
266
+ ], title=i18n('普通设置')),
267
+ sg.Frame(layout=[
268
+ [sg.Text(i18n("音频切分大小")),
269
+ sg.Slider(range=(0.05, 3.0), orientation='h', key='block', resolution=0.01, default_value=0.5,
270
+ enable_events=True)],
271
+ [sg.Text(i18n("交叉淡化时长")),
272
+ sg.Slider(range=(0.01, 0.15), orientation='h', key='crossfade', resolution=0.01,
273
+ default_value=0.04, enable_events=True)],
274
+ [sg.Text(i18n("使用历史区块数量")),
275
+ sg.Slider(range=(1, 20), orientation='h', key='buffernum', resolution=1, default_value=3,
276
+ enable_events=True)],
277
+ [sg.Text(i18n("f0预测模式")),
278
+ sg.Combo(values=self.f0_mode_list, key='f0_mode', default_value=self.f0_mode_list[2],
279
+ enable_events=True)],
280
+ [sg.Checkbox(text=i18n('启用增强器'), default=True, key='use_enhancer', enable_events=True),
281
+ sg.Checkbox(text=i18n('启用相位声码器'), default=False, key='use_phase_vocoder',
282
+ enable_events=True)]
283
+ ], title=i18n('性能设置')),
284
+ sg.Frame(layout=[
285
+ [sg.Text(i18n("扩散模型文件"))],
286
+ [sg.Input(key='diff_project', default_text='exp\\diffusion-test\\model_400000.pt'),
287
+ sg.FileBrowse(i18n('选择模型文件'), key='choose_model')],
288
+ [sg.Text(i18n("扩散说话人id")), sg.Input(key='diff_spk_id', default_text='1', size=18)],
289
+ [sg.Text(i18n("扩散深度")), sg.Input(key='k_step', default_text='120', size=18)],
290
+ [sg.Text(i18n("扩散加速")), sg.Input(key='diff_acc', default_text='20', size=18)],
291
+ [sg.Checkbox(text=i18n('启用DPMs(推荐)'), default=False, key='diff_use_dpm', enable_events=True)],
292
+ [sg.Checkbox(text=i18n('启用扩散'), default=True, key='diff_use', enable_events=True),
293
+ sg.Checkbox(text=i18n('不扩散安全区(加速但损失效果)'), default=False, key='diff_silence', enable_events=True)]
294
+ ], title=i18n('扩散设置')),
295
+ ],
296
+ [sg.Button(i18n("开始音频转换"), key="start_vc"), sg.Button(i18n("停止音频转换"), key="stop_vc"),
297
+ sg.Text(i18n('推理所用时间(ms):')), sg.Text('0', key='infer_time')]
298
+ ]
299
+
300
+ # 创造窗口
301
+ self.window = sg.Window('DDSP - GUI', layout, finalize=True)
302
+ self.window['spk_id'].bind('<Return>', '')
303
+ self.window['samplerate'].bind('<Return>', '')
304
+ self.window['diff_spk_id'].bind('<Return>', '')
305
+ self.window['k_step'].bind('<Return>', '')
306
+ self.window['diff_acc'].bind('<Return>', '')
307
+ self.event_handler()
308
+
309
+ def event_handler(self):
310
+ '''事件处理'''
311
+ while True: # 事件处理循环
312
+ event, values = self.window.read()
313
+ if event == sg.WINDOW_CLOSED: # 如果用户关闭窗口
314
+ self.flag_vc = False
315
+ exit()
316
+
317
+ print('event: ' + event)
318
+
319
+ if event == 'start_vc' and self.flag_vc == False:
320
+ # set values 和界面布局layout顺序一一对应
321
+ self.set_values(values)
322
+ print('crossfade_time:' + str(self.config.crossfade_time))
323
+ print("buffer_num:" + str(self.config.buffer_num))
324
+ print("samplerate:" + str(self.config.samplerate))
325
+ print('block_time:' + str(self.config.block_time))
326
+ print("prefix_pad_length:" + str(self.f_safe_prefix_pad_length))
327
+ print("mix_mode:" + str(self.config.spk_mix_dict))
328
+ print("enhancer:" + str(self.config.use_vocoder_based_enhancer))
329
+ print("diffusion:" + str(self.config.diff_use))
330
+ print('using_cuda:' + str(torch.cuda.is_available()))
331
+ self.start_vc()
332
+ elif event == 'k_step':
333
+ if 1 <= int(values['k_step']) <= 1000:
334
+ self.config.k_step = int(values['k_step'])
335
+ else:
336
+ self.window['k_step'].update(1000)
337
+ elif event == 'diff_acc':
338
+ if self.config.k_step < int(values['diff_acc']):
339
+ self.config.diff_acc = int(self.config.k_step / 4)
340
+ else:
341
+ self.config.diff_acc = int(values['diff_acc'])
342
+ elif event == 'diff_spk_id':
343
+ self.config.diff_spk_id = int(values['diff_spk_id'])
344
+ elif event == 'diff_use':
345
+ self.config.diff_use = values['diff_use']
346
+ self.window['use_enhancer'].update(False)
347
+ self.config.use_vocoder_based_enhancer=False
348
+ elif event == 'diff_silence':
349
+ self.config.diff_silence = values['diff_silence']
350
+ elif event == 'diff_use_dpm':
351
+ self.config.diff_use_dpm = values['diff_use_dpm']
352
+ elif event == 'spk_id':
353
+ self.config.spk_id = int(values['spk_id'])
354
+ elif event == 'threhold':
355
+ self.config.threhold = values['threhold']
356
+ elif event == 'pitch':
357
+ self.config.f_pitch_change = values['pitch']
358
+ elif event == 'spk_mix':
359
+ self.config.use_spk_mix = values['spk_mix']
360
+ elif event == 'set_spk_mix':
361
+ spk_mix = sg.popup_get_text(message='示例:1:0.3,2:0.5,3:0.2', title="设置混合音色,支持多人")
362
+ if spk_mix != None:
363
+ self.config.spk_mix_dict = eval("{" + spk_mix.replace(',', ',').replace(':', ':') + "}")
364
+ elif event == 'f0_mode':
365
+ self.config.select_pitch_extractor = values['f0_mode']
366
+ elif event == 'use_enhancer':
367
+ self.config.use_vocoder_based_enhancer = values['use_enhancer']
368
+ self.window['diff_use'].update(False)
369
+ self.config.diff_use = False
370
+ elif event == 'use_phase_vocoder':
371
+ self.config.use_phase_vocoder = values['use_phase_vocoder']
372
+ elif event == 'load_config' and self.flag_vc == False:
373
+ if self.config.load(values['config_file_dir']):
374
+ self.update_values()
375
+ elif event == 'save_config' and self.flag_vc == False:
376
+ self.set_values(values)
377
+ self.config.save(values['config_file_dir'])
378
+ elif event != 'start_vc' and self.flag_vc == True:
379
+ self.flag_vc = False
380
+
381
+ def set_values(self, values):
382
+ self.set_devices(values["sg_input_device"], values['sg_output_device'])
383
+ self.config.sounddevices = [values["sg_input_device"], values['sg_output_device']]
384
+ self.config.checkpoint_path = values['sg_model']
385
+ self.config.spk_id = int(values['spk_id'])
386
+ self.config.threhold = values['threhold']
387
+ self.config.f_pitch_change = values['pitch']
388
+ self.config.samplerate = int(values['samplerate'])
389
+ self.config.block_time = float(values['block'])
390
+ self.config.crossfade_time = float(values['crossfade'])
391
+ self.config.buffer_num = int(values['buffernum'])
392
+ self.config.select_pitch_extractor = values['f0_mode']
393
+ self.config.use_vocoder_based_enhancer = values['use_enhancer']
394
+ self.config.use_phase_vocoder = values['use_phase_vocoder']
395
+ self.config.use_spk_mix = values['spk_mix']
396
+ self.config.diff_use = values['diff_use']
397
+ self.config.diff_silence = values['diff_silence']
398
+ self.config.diff_use_dpm = values['diff_use_dpm']
399
+ self.config.diff_project = values['diff_project']
400
+ self.config.diff_acc = int(values['diff_acc'])
401
+ self.config.diff_spk_id = int(values['diff_spk_id'])
402
+ self.config.k_step = int(values['k_step'])
403
+ self.block_frame = int(self.config.block_time * self.config.samplerate)
404
+ self.crossfade_frame = int(self.config.crossfade_time * self.config.samplerate)
405
+ self.sola_search_frame = int(0.01 * self.config.samplerate)
406
+ self.last_delay_frame = int(0.02 * self.config.samplerate)
407
+ self.input_frames = max(
408
+ self.block_frame + self.crossfade_frame + self.sola_search_frame + 2 * self.last_delay_frame,
409
+ (1 + self.config.buffer_num) * self.block_frame)
410
+ self.f_safe_prefix_pad_length = self.config.block_time * self.config.buffer_num - self.config.crossfade_time - 0.01 - 0.02
411
+
412
+ def update_values(self):
413
+ self.window['sg_model'].update(self.config.checkpoint_path)
414
+ self.window['sg_input_device'].update(self.config.sounddevices[0])
415
+ self.window['sg_output_device'].update(self.config.sounddevices[1])
416
+ self.window['spk_id'].update(self.config.spk_id)
417
+ self.window['threhold'].update(self.config.threhold)
418
+ self.window['pitch'].update(self.config.f_pitch_change)
419
+ self.window['samplerate'].update(self.config.samplerate)
420
+ self.window['spk_mix'].update(self.config.use_spk_mix)
421
+ self.window['block'].update(self.config.block_time)
422
+ self.window['crossfade'].update(self.config.crossfade_time)
423
+ self.window['buffernum'].update(self.config.buffer_num)
424
+ self.window['f0_mode'].update(self.config.select_pitch_extractor)
425
+ self.window['use_enhancer'].update(self.config.use_vocoder_based_enhancer)
426
+ self.window['diff_use'].update(self.config.diff_use)
427
+ self.window['diff_silence'].update(self.config.diff_silence)
428
+ self.window['diff_use_dpm'].update(self.config.diff_use_dpm)
429
+ self.window['diff_project'].update(self.config.diff_project)
430
+ self.window['diff_acc'].update(self.config.diff_acc)
431
+ self.window['diff_spk_id'].update(self.config.diff_spk_id)
432
+ self.window['k_step'].update(self.config.k_step)
433
+
434
+ def start_vc(self):
435
+ '''开始音频转换'''
436
+ torch.cuda.empty_cache()
437
+ self.flag_vc = True
438
+ self.input_wav = np.zeros(self.input_frames, dtype='float32')
439
+ self.sola_buffer = torch.zeros(self.crossfade_frame, device=self.device)
440
+ self.fade_in_window = torch.sin(
441
+ np.pi * torch.arange(0, 1, 1 / self.crossfade_frame, device=self.device) / 2) ** 2
442
+ self.fade_out_window = 1 - self.fade_in_window
443
+ self.svc_model.update_model(self.config.checkpoint_path)
444
+ if self.config.diff_use:
445
+ self.diff_model.flush_model(self.config.diff_project, ddsp_config=self.svc_model.args)
446
+ thread_vc = threading.Thread(target=self.soundinput)
447
+ thread_vc.start()
448
+
449
+ def soundinput(self):
450
+ '''
451
+ 接受音频输入
452
+ '''
453
+ with sd.Stream(callback=self.audio_callback, blocksize=self.block_frame, samplerate=self.config.samplerate,
454
+ dtype='float32'):
455
+ while self.flag_vc:
456
+ time.sleep(self.config.block_time)
457
+ print('Audio block passed.')
458
+ print('ENDing VC')
459
+
460
+ def audio_callback(self, indata: np.ndarray, outdata: np.ndarray, frames, times, status):
461
+ '''
462
+ 音频处理
463
+ '''
464
+ start_time = time.perf_counter()
465
+ print("\nStarting callback")
466
+ self.input_wav[:] = np.roll(self.input_wav, -self.block_frame)
467
+ self.input_wav[-self.block_frame:] = librosa.to_mono(indata.T)
468
+
469
+ # infer
470
+ if self.config.diff_use:
471
+ _diff_model = self.diff_model
472
+ else:
473
+ _diff_model = None
474
+ _audio, _model_sr = self.svc_model.infer(
475
+ self.input_wav,
476
+ self.config.samplerate,
477
+ spk_id=self.config.spk_id,
478
+ threhold=self.config.threhold,
479
+ pitch_adjust=self.config.f_pitch_change,
480
+ use_spk_mix=self.config.use_spk_mix,
481
+ spk_mix_dict=self.config.spk_mix_dict,
482
+ use_enhancer=self.config.use_vocoder_based_enhancer,
483
+ pitch_extractor_type=self.config.select_pitch_extractor,
484
+ safe_prefix_pad_length=self.f_safe_prefix_pad_length,
485
+ diff_model=_diff_model,
486
+ diff_acc=self.config.diff_acc,
487
+ diff_spk_id=self.config.diff_spk_id,
488
+ diff_use=self.config.diff_use,
489
+ diff_use_dpm=self.config.diff_use_dpm,
490
+ k_step=self.config.k_step,
491
+ diff_silence=self.config.diff_silence
492
+ )
493
+
494
+ # debug sola
495
+ '''
496
+ _audio, _model_sr = self.input_wav, self.config.samplerate
497
+ rs = int(np.random.uniform(-200,200))
498
+ print('debug_random_shift: ' + str(rs))
499
+ _audio = np.roll(_audio, rs)
500
+ _audio = torch.from_numpy(_audio).to(self.device)
501
+ '''
502
+
503
+ if _model_sr != self.config.samplerate:
504
+ key_str = str(_model_sr) + '_' + str(self.config.samplerate)
505
+ if key_str not in self.resample_kernel:
506
+ self.resample_kernel[key_str] = Resample(_model_sr, self.config.samplerate,
507
+ lowpass_filter_width=128).to(self.device)
508
+ _audio = self.resample_kernel[key_str](_audio)
509
+ temp_wav = _audio[
510
+ - self.block_frame - self.crossfade_frame - self.sola_search_frame - self.last_delay_frame: - self.last_delay_frame]
511
+
512
+ # sola shift
513
+ conv_input = temp_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
514
+ cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
515
+ cor_den = torch.sqrt(
516
+ F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=self.device)) + 1e-8)
517
+ sola_shift = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
518
+ temp_wav = temp_wav[sola_shift: sola_shift + self.block_frame + self.crossfade_frame]
519
+ print('sola_shift: ' + str(int(sola_shift)))
520
+
521
+ # phase vocoder
522
+ if self.config.use_phase_vocoder:
523
+ temp_wav[: self.crossfade_frame] = phase_vocoder(
524
+ self.sola_buffer,
525
+ temp_wav[: self.crossfade_frame],
526
+ self.fade_out_window,
527
+ self.fade_in_window)
528
+ else:
529
+ temp_wav[: self.crossfade_frame] *= self.fade_in_window
530
+ temp_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
531
+
532
+ self.sola_buffer = temp_wav[- self.crossfade_frame:]
533
+
534
+ outdata[:] = temp_wav[: - self.crossfade_frame, None].repeat(1, 2).cpu().numpy()
535
+ end_time = time.perf_counter()
536
+ print('infer_time: ' + str(end_time - start_time))
537
+ self.window['infer_time'].update(int((end_time - start_time) * 1000))
538
+
539
+ def get_devices(self, update: bool = True):
540
+ '''获取设备列表'''
541
+ if update:
542
+ sd._terminate()
543
+ sd._initialize()
544
+ devices = sd.query_devices()
545
+ hostapis = sd.query_hostapis()
546
+ for hostapi in hostapis:
547
+ for device_idx in hostapi["devices"]:
548
+ devices[device_idx]["hostapi_name"] = hostapi["name"]
549
+ input_devices = [
550
+ f"{d['name']} ({d['hostapi_name']})"
551
+ for d in devices
552
+ if d["max_input_channels"] > 0
553
+ ]
554
+ output_devices = [
555
+ f"{d['name']} ({d['hostapi_name']})"
556
+ for d in devices
557
+ if d["max_output_channels"] > 0
558
+ ]
559
+ input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
560
+ output_devices_indices = [
561
+ d["index"] for d in devices if d["max_output_channels"] > 0
562
+ ]
563
+ return input_devices, output_devices, input_devices_indices, output_devices_indices
564
+
565
+ def set_devices(self, input_device, output_device):
566
+ '''设置输出设备'''
567
+ input_devices, output_devices, input_device_indices, output_device_indices = self.get_devices()
568
+ sd.default.device[0] = input_device_indices[input_devices.index(input_device)]
569
+ sd.default.device[1] = output_device_indices[output_devices.index(output_device)]
570
+ print("input device:" + str(sd.default.device[0]) + ":" + str(input_device))
571
+ print("output device:" + str(sd.default.device[1]) + ":" + str(output_device))
572
+
573
+
574
+ if __name__ == "__main__":
575
+ i18n = I18nAuto()
576
+ gui = GUI()
DDSP-SVC/gui_diff_locale.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import locale
2
+ '''
3
+ 本地化方式如下所示
4
+ '''
5
+
6
+ LANGUAGE_LIST = ['zh_CN', 'en_US', 'ja_JP']
7
+ LANGUAGE_ALL = {
8
+ 'zh_CN': {
9
+ 'SUPER': 'END',
10
+ 'LANGUAGE': 'zh_CN',
11
+ '选择模型文件': '选择模型文件',
12
+ '模型:.pt格式(自动识别同目录下config.yaml)': '模型:.pt格式(自动识别同目录下config.yaml)',
13
+ '选择配置文件所在目录': '选择配置文件所在目录',
14
+ '打开文件夹': '打开文件夹',
15
+ '读取配置文件': '读取配置文件',
16
+ '保存配置文件': '保存配置文件',
17
+ '快速配置文件': '快速配置文件',
18
+ '输入设备': '输入设备',
19
+ '输出设备': '输出设备',
20
+ '音频设备': '音频设备',
21
+ '说话人id': '说话人id',
22
+ '响应阈值': '响应阈值',
23
+ '变调': '变调',
24
+ '采样率': '采样率',
25
+ '启用捏音色功能': '启用捏音色功能',
26
+ '设置混合音色': '设置混合音色',
27
+ '普通设置': '普通设置',
28
+ '音频切分大小': '音频切分大小',
29
+ '交叉淡化时长': '交叉淡化时长',
30
+ '使用历史区块数量': '使用历史区块数量',
31
+ 'f0预测模式': 'f0预测模式',
32
+ '启用增强器': '启用增强器',
33
+ '启用相位声码器': '启用相位声码器',
34
+ '性能设置': '性能设置',
35
+ '开始音频转换': '开始音频转换',
36
+ '停止音频转换': '停止音频转换',
37
+ '推理所用时间(ms):': '推理所用时间(ms):',
38
+ '扩散设置': '扩散设置',
39
+ '启用扩散': '启用扩散',
40
+ '扩散加速': '扩散加速',
41
+ '扩散深度': '扩散深度',
42
+ '扩散说话人id': '扩散说话人id',
43
+ '扩散模型文件': '扩散模型文件',
44
+ '不扩散安全区(加速但损失效果)': '不扩散安全区(加速但损失效果)',
45
+ '启用DPMs(推荐)': '启用DPMs(推荐)'
46
+ },
47
+ 'en_US': {
48
+ 'SUPER': 'zh_CN',
49
+ 'LANGUAGE': 'en_US',
50
+ '选择模型文件': 'Select Model File',
51
+ '模型:.pt格式(自动识别同目录下config.yaml)': 'Model:.pt format(Auto ust config.yaml in here)',
52
+ '选择配置文件所在目录': 'Select the configuration file directory',
53
+ '打开文件夹': 'Open folder',
54
+ '读取配置文件': 'Read config file',
55
+ '保存配置文件': 'Save config file',
56
+ '快速配置文件': 'Fast config file',
57
+ '输入设备': 'Input device',
58
+ '输出设备': 'Output device',
59
+ '音频设备': 'Audio devices',
60
+ '说话人id': 'Speaker ID',
61
+ '响应阈值': 'Response threshold',
62
+ '变调': 'Pitch',
63
+ '采样率': 'Sampling rate',
64
+ '启用捏音色功能': 'Enable Mix Speaker',
65
+ '设置混合音色': 'Mix Speaker',
66
+ '普通设置': 'Normal Settings',
67
+ '音频切分大小': 'Segmentation size',
68
+ '交叉淡化时长': 'Cross fade duration',
69
+ '使用历史区块数量': 'Historical blocks used',
70
+ 'f0预测模式': 'f0Extractor',
71
+ '启用增强器': 'Enable Enhancer',
72
+ '启用相位声码器': 'Enable Phase Vocoder',
73
+ '性能设置': 'Performance settings',
74
+ '开始音频转换': 'Start conversion',
75
+ '停止音频转换': 'Stop conversion',
76
+ '推理所用时间(ms):': 'Inference time(ms):',
77
+ '扩散设置': '扩散设置',
78
+ '启用扩散': '启用扩散',
79
+ '扩散加速': '扩散加速',
80
+ '扩散深度': '扩散深度',
81
+ '扩散说话人id': '扩散说话人id',
82
+ '扩散模型文件': '扩散模型文件',
83
+ '不扩散安全区(加速但损失效果)': '不扩散安全区(加速但损失效果)',
84
+ '启用DPMs(推荐)': '启用DPMs(推荐)'
85
+ },
86
+ 'ja_JP': {
87
+ 'SUPER': 'zh_CN',
88
+ 'LANGUAGE': 'ja_JP',
89
+ '选择模型文件': 'モデルを選択',
90
+ '模型:.pt格式(自动识别同目录下config.yaml)': 'モデル:.pt形式(同じディレクトリにあるconfig.yamlを自動認識します)',
91
+ '选择配置文件所在目录': '設定ファイルを選択',
92
+ '打开文件夹': 'フォルダを開く',
93
+ '读取配置文件': '設定ファイルを読み込む',
94
+ '保存配置文件': '設定ファイルを保存',
95
+ '快速配置文件': '設定プロファイル',
96
+ '输入设备': '入力デバイス',
97
+ '输出设备': '出力デバイス',
98
+ '音频设备': '音声デバイス',
99
+ '说话人id': '話者ID',
100
+ '响应阈值': '応答時の閾値',
101
+ '变调': '音程',
102
+ '采样率': 'サンプリングレート',
103
+ '启用捏音色功能': 'ミキシングを有効化',
104
+ '设置混合音色': 'ミキシング',
105
+ '普通设置': '通常設定',
106
+ '音频切分大小': 'セグメンテーションのサイズ',
107
+ '交叉淡化���长': 'クロスフェードの間隔',
108
+ '使用历史区块数量': '使用するヒストリカルブロック数',
109
+ 'f0预测模式': 'f0予測モデル',
110
+ '启用增强器': 'Enhancerを有効化',
111
+ '启用相位声码器': 'フェーズボコーダを有効化',
112
+ '性能设置': 'パフォーマンスの設定',
113
+ '开始音频转换': '変換開始',
114
+ '停止音频转换': '変換停止',
115
+ '推理所用时间(ms):': '推論時間(ms):',
116
+ '扩散设置': '扩散设置',
117
+ '启用扩散': '启用扩散',
118
+ '扩散加速': '扩散加速',
119
+ '扩散深度': '扩散深度',
120
+ '扩散说话人id': '扩散说话人id',
121
+ '扩散模型文件': '扩散模型文件',
122
+ '不扩散安全区(加速但损失效果)': '不扩散安全区(加速但损失效果)',
123
+ '启用DPMs(推荐)': '启用DPMs(推荐)'
124
+ }
125
+ }
126
+
127
+
128
+ class I18nAuto:
129
+ def __init__(self, language=None):
130
+ self.language_list = LANGUAGE_LIST
131
+ self.language_all = LANGUAGE_ALL
132
+ self.language_map = {}
133
+ if language is None:
134
+ language = 'auto'
135
+ if language == 'auto':
136
+ language = locale.getdefaultlocale()[0]
137
+ if language not in self.language_list:
138
+ language = 'zh_CN'
139
+ self.language = language
140
+ super_language_list = []
141
+ while self.language_all[language]['SUPER'] != 'END':
142
+ super_language_list.append(language)
143
+ language = self.language_all[language]['SUPER']
144
+ super_language_list.append('zh_CN')
145
+ super_language_list.reverse()
146
+ for _lang in super_language_list:
147
+ self.read_language(self.language_all[_lang])
148
+
149
+ def read_language(self, lang_dict: dict):
150
+ for _key in lang_dict.keys():
151
+ self.language_map[_key] = lang_dict[_key]
152
+
153
+ def __call__(self, key):
154
+ return self.language_map[key]
DDSP-SVC/gui_locale.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import locale
2
+ '''
3
+ 本地化方式如下所示
4
+ '''
5
+
6
+ LANGUAGE_LIST = ['zh_CN', 'en_US', 'ja_JP']
7
+ LANGUAGE_ALL = {
8
+ 'zh_CN': {
9
+ 'SUPER': 'END',
10
+ 'LANGUAGE': 'zh_CN',
11
+ '选择模型文件': '选择模型文件',
12
+ '模型:.pt格式(自动识别同目录下config.yaml)': '模型:.pt格式(自动识别同目录下config.yaml)',
13
+ '选择配置文件所在目录': '选择配置文件所在目录',
14
+ '打开文件夹': '打开文件夹',
15
+ '读取配置文件': '读取配置文件',
16
+ '保存配置文件': '保存配置文件',
17
+ '快速配置文件': '快速配置文件',
18
+ '输入设备': '输入设备',
19
+ '输出设备': '输出设备',
20
+ '音频设备': '音频设备',
21
+ '说话人id': '说话人id',
22
+ '响应阈值': '响应阈值',
23
+ '变调': '变调',
24
+ '采样率': '采样率',
25
+ '启用捏音色功能': '启用捏音色功能',
26
+ '设置混合音色': '设置混合音色',
27
+ '普通设置': '普通设置',
28
+ '音频切分大小': '音频切分大小',
29
+ '交叉淡化时长': '交叉淡化时长',
30
+ '使用历史区块数量': '使用历史区块数量',
31
+ 'f0预测模式': 'f0预测模式',
32
+ '启用增强器': '启用增强器',
33
+ '启用相位声码器': '启用相位声码器',
34
+ '性能设置': '性能设置',
35
+ '开始音频转换': '开始音频转换',
36
+ '停止音频转换': '停止音频转换',
37
+ '推理所用时间(ms):': '推理所用时间(ms):'
38
+ },
39
+ 'en_US': {
40
+ 'SUPER': 'zh_CN',
41
+ 'LANGUAGE': 'en_US',
42
+ '选择模型文件': 'Select Model File',
43
+ '模型:.pt格式(自动识别同目录下config.yaml)': 'Model:.pt format(Auto ust config.yaml in here)',
44
+ '选择配置文件所在目录': 'Select the configuration file directory',
45
+ '打开文件夹': 'Open folder',
46
+ '读取配置文件': 'Read config file',
47
+ '保存配置文件': 'Save config file',
48
+ '快速配置文件': 'Fast config file',
49
+ '输入设备': 'Input device',
50
+ '输出设备': 'Output device',
51
+ '音频设备': 'Audio devices',
52
+ '说话人id': 'Speaker ID',
53
+ '响应阈值': 'Response threshold',
54
+ '变调': 'Pitch',
55
+ '采样率': 'Sampling rate',
56
+ '启用捏音色功能': 'Enable Mix Speaker',
57
+ '设置混合音色': 'Mix Speaker',
58
+ '普通设置': 'Normal Settings',
59
+ '音频切分大小': 'Segmentation size',
60
+ '交叉淡化时长': 'Cross fade duration',
61
+ '使用历史区块数量': 'Historical blocks used',
62
+ 'f0预测模式': 'f0Extractor',
63
+ '启用增强器': 'Enable Enhancer',
64
+ '启用相位声码器': 'Enable Phase Vocoder',
65
+ '性能设置': 'Performance settings',
66
+ '开始音频转换': 'Start conversion',
67
+ '停止音频转换': 'Stop conversion',
68
+ '推理所用时间(ms):': 'Inference time(ms):'
69
+ },
70
+ 'ja_JP': {
71
+ 'SUPER': 'zh_CN',
72
+ 'LANGUAGE': 'ja_JP',
73
+ '选择模型文件': 'モデルを選択',
74
+ '模型:.pt格式(自动识别同目录下config.yaml)': 'モデル:.pt形式(同じディレクトリにあるconfig.yamlを自動認識します)',
75
+ '选择配置文件所在目录': '設定ファイルを選択',
76
+ '打开文件夹': 'フォルダを開く',
77
+ '读取配置文件': '設定ファイルを読み込む',
78
+ '保存配置文件': '設定ファイルを保存',
79
+ '快速配置文件': '設定プロファイル',
80
+ '输入设备': '入力デバイス',
81
+ '输出设备': '出力デバイス',
82
+ '音频设备': '音声デバイス',
83
+ '说话人id': '話者ID',
84
+ '响应阈值': '応答時の閾値',
85
+ '变调': '音程',
86
+ '采样率': 'サンプリングレート',
87
+ '启用捏音色功能': 'ミキシングを有効化',
88
+ '设置混合音色': 'ミキシング',
89
+ '普通设置': '通常設定',
90
+ '音频切分大小': 'セグメンテーションのサイズ',
91
+ '交叉淡化时长': 'クロスフェードの間隔',
92
+ '使用历史区块数量': '使用するヒストリカルブロック数',
93
+ 'f0预测模式': 'f0予測モデル',
94
+ '启用增强器': 'Enhancerを有効化',
95
+ '启用相位声码器': 'フェーズボコーダを有効化',
96
+ '性能设置': 'パフォーマンスの設定',
97
+ '开始音频转换': '変換開始',
98
+ '停止音频转换': '変換停止',
99
+ '推理所用时间(ms):': '推論時間(ms):'
100
+ }
101
+ }
102
+
103
+
104
+ class I18nAuto:
105
+ def __init__(self, language=None):
106
+ self.language_list = LANGUAGE_LIST
107
+ self.language_all = LANGUAGE_ALL
108
+ self.language_map = {}
109
+ if language is None:
110
+ language = 'auto'
111
+ if language == 'auto':
112
+ language = locale.getdefaultlocale()[0]
113
+ if language not in self.language_list:
114
+ language = 'zh_CN'
115
+ self.language = language
116
+ super_language_list = []
117
+ while self.language_all[language]['SUPER'] != 'END':
118
+ super_language_list.append(language)
119
+ language = self.language_all[language]['SUPER']
120
+ super_language_list.append('zh_CN')
121
+ super_language_list.reverse()
122
+ for _lang in super_language_list:
123
+ self.read_language(self.language_all[_lang])
124
+
125
+ def read_language(self, lang_dict: dict):
126
+ for _key in lang_dict.keys():
127
+ self.language_map[_key] = lang_dict[_key]
128
+
129
+ def __call__(self, key):
130
+ return self.language_map[key]
DDSP-SVC/logger/__init__.py ADDED
File without changes
DDSP-SVC/logger/saver.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ author: wayn391@mastertones
3
+ '''
4
+
5
+ import os
6
+ import json
7
+ import time
8
+ import yaml
9
+ import datetime
10
+ import torch
11
+ import matplotlib.pyplot as plt
12
+ from . import utils
13
+ from torch.utils.tensorboard import SummaryWriter
14
+
15
+ class Saver(object):
16
+ def __init__(
17
+ self,
18
+ args,
19
+ initial_global_step=-1):
20
+
21
+ self.expdir = args.env.expdir
22
+ self.sample_rate = args.data.sampling_rate
23
+
24
+ # cold start
25
+ self.global_step = initial_global_step
26
+ self.init_time = time.time()
27
+ self.last_time = time.time()
28
+
29
+ # makedirs
30
+ os.makedirs(self.expdir, exist_ok=True)
31
+
32
+ # path
33
+ self.path_log_info = os.path.join(self.expdir, 'log_info.txt')
34
+
35
+ # ckpt
36
+ os.makedirs(self.expdir, exist_ok=True)
37
+
38
+ # writer
39
+ self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))
40
+
41
+ # save config
42
+ path_config = os.path.join(self.expdir, 'config.yaml')
43
+ with open(path_config, "w") as out_config:
44
+ yaml.dump(dict(args), out_config)
45
+
46
+
47
+ def log_info(self, msg):
48
+ '''log method'''
49
+ if isinstance(msg, dict):
50
+ msg_list = []
51
+ for k, v in msg.items():
52
+ tmp_str = ''
53
+ if isinstance(v, int):
54
+ tmp_str = '{}: {:,}'.format(k, v)
55
+ else:
56
+ tmp_str = '{}: {}'.format(k, v)
57
+
58
+ msg_list.append(tmp_str)
59
+ msg_str = '\n'.join(msg_list)
60
+ else:
61
+ msg_str = msg
62
+
63
+ # dsplay
64
+ print(msg_str)
65
+
66
+ # save
67
+ with open(self.path_log_info, 'a') as fp:
68
+ fp.write(msg_str+'\n')
69
+
70
+ def log_value(self, dict):
71
+ for k, v in dict.items():
72
+ self.writer.add_scalar(k, v, self.global_step)
73
+
74
+ def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
75
+ spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
76
+ spec = spec_cat[0]
77
+ if isinstance(spec, torch.Tensor):
78
+ spec = spec.cpu().numpy()
79
+ fig = plt.figure(figsize=(12, 9))
80
+ plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
81
+ plt.tight_layout()
82
+ self.writer.add_figure(name, fig, self.global_step)
83
+
84
+ def log_audio(self, dict):
85
+ for k, v in dict.items():
86
+ self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)
87
+
88
+ def get_interval_time(self, update=True):
89
+ cur_time = time.time()
90
+ time_interval = cur_time - self.last_time
91
+ if update:
92
+ self.last_time = cur_time
93
+ return time_interval
94
+
95
+ def get_total_time(self, to_str=True):
96
+ total_time = time.time() - self.init_time
97
+ if to_str:
98
+ total_time = str(datetime.timedelta(
99
+ seconds=total_time))[:-5]
100
+ return total_time
101
+
102
+ def save_model(
103
+ self,
104
+ model,
105
+ optimizer,
106
+ name='model',
107
+ postfix='',
108
+ to_json=False):
109
+ # path
110
+ if postfix:
111
+ postfix = '_' + postfix
112
+ path_pt = os.path.join(
113
+ self.expdir , name+postfix+'.pt')
114
+
115
+ # check
116
+ print(' [*] model checkpoint saved: {}'.format(path_pt))
117
+
118
+ # save
119
+ torch.save({
120
+ 'global_step': self.global_step,
121
+ 'model': model.state_dict(),
122
+ 'optimizer': optimizer.state_dict()}, path_pt)
123
+
124
+ # to json
125
+ if to_json:
126
+ path_json = os.path.join(
127
+ self.expdir , name+'.json')
128
+ utils.to_json(path_params, path_json)
129
+
130
+ def delete_model(self, name='model', postfix=''):
131
+ # path
132
+ if postfix:
133
+ postfix = '_' + postfix
134
+ path_pt = os.path.join(
135
+ self.expdir , name+postfix+'.pt')
136
+
137
+ # delete
138
+ if os.path.exists(path_pt):
139
+ os.remove(path_pt)
140
+ print(' [*] model checkpoint deleted: {}'.format(path_pt))
141
+
142
+ def global_step_increment(self):
143
+ self.global_step += 1
144
+
145
+
DDSP-SVC/logger/utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import json
4
+ import pickle
5
+ import torch
6
+
7
+ def traverse_dir(
8
+ root_dir,
9
+ extension,
10
+ amount=None,
11
+ str_include=None,
12
+ str_exclude=None,
13
+ is_pure=False,
14
+ is_sort=False,
15
+ is_ext=True):
16
+
17
+ file_list = []
18
+ cnt = 0
19
+ for root, _, files in os.walk(root_dir):
20
+ for file in files:
21
+ if file.endswith(extension):
22
+ # path
23
+ mix_path = os.path.join(root, file)
24
+ pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
25
+
26
+ # amount
27
+ if (amount is not None) and (cnt == amount):
28
+ if is_sort:
29
+ file_list.sort()
30
+ return file_list
31
+
32
+ # check string
33
+ if (str_include is not None) and (str_include not in pure_path):
34
+ continue
35
+ if (str_exclude is not None) and (str_exclude in pure_path):
36
+ continue
37
+
38
+ if not is_ext:
39
+ ext = pure_path.split('.')[-1]
40
+ pure_path = pure_path[:-(len(ext)+1)]
41
+ file_list.append(pure_path)
42
+ cnt += 1
43
+ if is_sort:
44
+ file_list.sort()
45
+ return file_list
46
+
47
+
48
+
49
+ class DotDict(dict):
50
+ def __getattr__(*args):
51
+ val = dict.get(*args)
52
+ return DotDict(val) if type(val) is dict else val
53
+
54
+ __setattr__ = dict.__setitem__
55
+ __delattr__ = dict.__delitem__
56
+
57
+
58
+ def get_network_paras_amount(model_dict):
59
+ info = dict()
60
+ for model_name, model in model_dict.items():
61
+ # all_params = sum(p.numel() for p in model.parameters())
62
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
63
+
64
+ info[model_name] = trainable_params
65
+ return info
66
+
67
+
68
+ def load_config(path_config):
69
+ with open(path_config, "r") as config:
70
+ args = yaml.safe_load(config)
71
+ args = DotDict(args)
72
+ # print(args)
73
+ return args
74
+
75
+
76
+ def to_json(path_params, path_json):
77
+ params = torch.load(path_params, map_location=torch.device('cpu'))
78
+ raw_state_dict = {}
79
+ for k, v in params.items():
80
+ val = v.flatten().numpy().tolist()
81
+ raw_state_dict[k] = val
82
+
83
+ with open(path_json, 'w') as outfile:
84
+ json.dump(raw_state_dict, outfile,indent= "\t")
85
+
86
+
87
+ def convert_tensor_to_numpy(tensor, is_squeeze=True):
88
+ if is_squeeze:
89
+ tensor = tensor.squeeze()
90
+ if tensor.requires_grad:
91
+ tensor = tensor.detach()
92
+ if tensor.is_cuda:
93
+ tensor = tensor.cpu()
94
+ return tensor.numpy()
95
+
96
+
97
+ def load_model(
98
+ expdir,
99
+ model,
100
+ optimizer,
101
+ name='model',
102
+ postfix='',
103
+ device='cpu'):
104
+ if postfix == '':
105
+ postfix = '_' + postfix
106
+ path = os.path.join(expdir, name+postfix)
107
+ path_pt = traverse_dir(expdir, '.pt', is_ext=False)
108
+ global_step = 0
109
+ if len(path_pt) > 0:
110
+ steps = [s[len(path):] for s in path_pt]
111
+ maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
112
+ if maxstep >= 0:
113
+ path_pt = path+str(maxstep)+'.pt'
114
+ else:
115
+ path_pt = path+'best.pt'
116
+ print(' [*] restoring model from', path_pt)
117
+ ckpt = torch.load(path_pt, map_location=torch.device(device))
118
+ global_step = ckpt['global_step']
119
+ model.load_state_dict(ckpt['model'], strict=False)
120
+ if maxstep != 0:
121
+ optimizer.load_state_dict(ckpt['optimizer'])
122
+ return global_step, model, optimizer
DDSP-SVC/main.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import argparse
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import pyworld as pw
8
+ import parselmouth
9
+ import hashlib
10
+ from ast import literal_eval
11
+ from slicer import Slicer
12
+ from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
13
+ from ddsp.core import upsample
14
+ from enhancer import Enhancer
15
+ from tqdm import tqdm
16
+
17
+ def parse_args(args=None, namespace=None):
18
+ """Parse command-line arguments."""
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument(
21
+ "-m",
22
+ "--model_path",
23
+ type=str,
24
+ required=True,
25
+ help="path to the model file",
26
+ )
27
+ parser.add_argument(
28
+ "-d",
29
+ "--device",
30
+ type=str,
31
+ default=None,
32
+ required=False,
33
+ help="cpu or cuda, auto if not set")
34
+ parser.add_argument(
35
+ "-i",
36
+ "--input",
37
+ type=str,
38
+ required=True,
39
+ help="path to the input audio file",
40
+ )
41
+ parser.add_argument(
42
+ "-o",
43
+ "--output",
44
+ type=str,
45
+ required=True,
46
+ help="path to the output audio file",
47
+ )
48
+ parser.add_argument(
49
+ "-id",
50
+ "--spk_id",
51
+ type=str,
52
+ required=False,
53
+ default=1,
54
+ help="speaker id (for multi-speaker model) | default: 1",
55
+ )
56
+ parser.add_argument(
57
+ "-mix",
58
+ "--spk_mix_dict",
59
+ type=str,
60
+ required=False,
61
+ default="None",
62
+ help="mix-speaker dictionary (for multi-speaker model) | default: None",
63
+ )
64
+ parser.add_argument(
65
+ "-k",
66
+ "--key",
67
+ type=str,
68
+ required=False,
69
+ default=0,
70
+ help="key changed (number of semitones) | default: 0",
71
+ )
72
+ parser.add_argument(
73
+ "-e",
74
+ "--enhance",
75
+ type=str,
76
+ required=False,
77
+ default='true',
78
+ help="true or false | default: true",
79
+ )
80
+ parser.add_argument(
81
+ "-pe",
82
+ "--pitch_extractor",
83
+ type=str,
84
+ required=False,
85
+ default='crepe',
86
+ help="pitch extrator type: parselmouth, dio, harvest, crepe (default)",
87
+ )
88
+ parser.add_argument(
89
+ "-fmin",
90
+ "--f0_min",
91
+ type=str,
92
+ required=False,
93
+ default=50,
94
+ help="min f0 (Hz) | default: 50",
95
+ )
96
+ parser.add_argument(
97
+ "-fmax",
98
+ "--f0_max",
99
+ type=str,
100
+ required=False,
101
+ default=1100,
102
+ help="max f0 (Hz) | default: 1100",
103
+ )
104
+ parser.add_argument(
105
+ "-th",
106
+ "--threhold",
107
+ type=str,
108
+ required=False,
109
+ default=-60,
110
+ help="response threhold (dB) | default: -60",
111
+ )
112
+ parser.add_argument(
113
+ "-eak",
114
+ "--enhancer_adaptive_key",
115
+ type=str,
116
+ required=False,
117
+ default=0,
118
+ help="adapt the enhancer to a higher vocal range (number of semitones) | default: 0",
119
+ )
120
+ return parser.parse_args(args=args, namespace=namespace)
121
+
122
+
123
+ def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000):
124
+ slicer = Slicer(
125
+ sr=sample_rate,
126
+ threshold=db_thresh,
127
+ min_length=min_len)
128
+ chunks = dict(slicer.slice(audio))
129
+ result = []
130
+ for k, v in chunks.items():
131
+ tag = v["split_time"].split(",")
132
+ if tag[0] != tag[1]:
133
+ start_frame = int(int(tag[0]) // hop_size)
134
+ end_frame = int(int(tag[1]) // hop_size)
135
+ if end_frame > start_frame:
136
+ result.append((
137
+ start_frame,
138
+ audio[int(start_frame * hop_size) : int(end_frame * hop_size)]))
139
+ return result
140
+
141
+
142
+ def cross_fade(a: np.ndarray, b: np.ndarray, idx: int):
143
+ result = np.zeros(idx + b.shape[0])
144
+ fade_len = a.shape[0] - idx
145
+ np.copyto(dst=result[:idx], src=a[:idx])
146
+ k = np.linspace(0, 1.0, num=fade_len, endpoint=True)
147
+ result[idx: a.shape[0]] = (1 - k) * a[idx:] + k * b[: fade_len]
148
+ np.copyto(dst=result[a.shape[0]:], src=b[fade_len:])
149
+ return result
150
+
151
+
152
+ if __name__ == '__main__':
153
+ # parse commands
154
+ cmd = parse_args()
155
+
156
+ #device = 'cpu'
157
+ device = cmd.device
158
+ if device is None:
159
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
160
+
161
+ # load ddsp model
162
+ model, args = load_model(cmd.model_path, device=device)
163
+
164
+ # load input
165
+ audio, sample_rate = librosa.load(cmd.input, sr=None)
166
+ if len(audio.shape) > 1:
167
+ audio = librosa.to_mono(audio)
168
+ hop_size = args.data.block_size * sample_rate / args.data.sampling_rate
169
+
170
+ # get MD5 hash from wav file
171
+ md5_hash = ""
172
+ with open(cmd.input, 'rb') as f:
173
+ data = f.read()
174
+ md5_hash = hashlib.md5(data).hexdigest()
175
+ print("MD5: " + md5_hash)
176
+
177
+ cache_dir_path = os.path.join(os.path.dirname(__file__), "cache")
178
+ cache_file_path = os.path.join(cache_dir_path, f"{cmd.pitch_extractor}_{hop_size}_{cmd.f0_min}_{cmd.f0_max}_{md5_hash}.npy")
179
+
180
+ is_cache_available = os.path.exists(cache_file_path)
181
+ if is_cache_available:
182
+ # f0 cache load
183
+ print('Loading pitch curves for input audio from cache directory...')
184
+ f0 = np.load(cache_file_path, allow_pickle=False)
185
+ else:
186
+ # extract f0
187
+ print('Pitch extractor type: ' + cmd.pitch_extractor)
188
+ pitch_extractor = F0_Extractor(
189
+ cmd.pitch_extractor,
190
+ sample_rate,
191
+ hop_size,
192
+ float(cmd.f0_min),
193
+ float(cmd.f0_max))
194
+ print('Extracting the pitch curve of the input audio...')
195
+ f0 = pitch_extractor.extract(audio, uv_interp = True, device = device)
196
+
197
+ # f0 cache save
198
+ os.makedirs(cache_dir_path, exist_ok=True)
199
+ np.save(cache_file_path, f0, allow_pickle=False)
200
+
201
+ f0 = torch.from_numpy(f0).float().to(device).unsqueeze(-1).unsqueeze(0)
202
+
203
+ # key change
204
+ f0 = f0 * 2 ** (float(cmd.key) / 12)
205
+
206
+ # extract volume
207
+ print('Extracting the volume envelope of the input audio...')
208
+ volume_extractor = Volume_Extractor(hop_size)
209
+ volume = volume_extractor.extract(audio)
210
+ mask = (volume > 10 ** (float(cmd.threhold) / 20)).astype('float')
211
+ mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
212
+ mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
213
+ mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0)
214
+ mask = upsample(mask, args.data.block_size).squeeze(-1)
215
+ volume = torch.from_numpy(volume).float().to(device).unsqueeze(-1).unsqueeze(0)
216
+
217
+ # load units encoder
218
+ if args.data.encoder == 'cnhubertsoftfish':
219
+ cnhubertsoft_gate = args.data.cnhubertsoft_gate
220
+ else:
221
+ cnhubertsoft_gate = 10
222
+ units_encoder = Units_Encoder(
223
+ args.data.encoder,
224
+ args.data.encoder_ckpt,
225
+ args.data.encoder_sample_rate,
226
+ args.data.encoder_hop_size,
227
+ cnhubertsoft_gate=cnhubertsoft_gate,
228
+ device = device)
229
+
230
+ # load enhancer
231
+ if cmd.enhance == 'true':
232
+ print('Enhancer type: ' + args.enhancer.type)
233
+ enhancer = Enhancer(args.enhancer.type, args.enhancer.ckpt, device=device)
234
+ else:
235
+ print('Enhancer type: none (using raw output of ddsp)')
236
+
237
+ # speaker id or mix-speaker dictionary
238
+ spk_mix_dict = literal_eval(cmd.spk_mix_dict)
239
+ if spk_mix_dict is not None:
240
+ print('Mix-speaker mode')
241
+ else:
242
+ print('Speaker ID: '+ str(int(cmd.spk_id)))
243
+ spk_id = torch.LongTensor(np.array([[int(cmd.spk_id)]])).to(device)
244
+
245
+ # forward and save the output
246
+ result = np.zeros(0)
247
+ current_length = 0
248
+ segments = split(audio, sample_rate, hop_size)
249
+ print('Cut the input audio into ' + str(len(segments)) + ' slices')
250
+ with torch.no_grad():
251
+ for segment in tqdm(segments):
252
+ start_frame = segment[0]
253
+ seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(device)
254
+ seg_units = units_encoder.encode(seg_input, sample_rate, hop_size)
255
+
256
+ seg_f0 = f0[:, start_frame : start_frame + seg_units.size(1), :]
257
+ seg_volume = volume[:, start_frame : start_frame + seg_units.size(1), :]
258
+
259
+ seg_output, _, (s_h, s_n) = model(seg_units, seg_f0, seg_volume, spk_id = spk_id, spk_mix_dict = spk_mix_dict)
260
+ seg_output *= mask[:, start_frame * args.data.block_size : (start_frame + seg_units.size(1)) * args.data.block_size]
261
+
262
+ if cmd.enhance == 'true':
263
+ seg_output, output_sample_rate = enhancer.enhance(
264
+ seg_output,
265
+ args.data.sampling_rate,
266
+ seg_f0,
267
+ args.data.block_size,
268
+ adaptive_key = cmd.enhancer_adaptive_key)
269
+ else:
270
+ output_sample_rate = args.data.sampling_rate
271
+
272
+ seg_output = seg_output.squeeze().cpu().numpy()
273
+
274
+ silent_length = round(start_frame * args.data.block_size * output_sample_rate / args.data.sampling_rate) - current_length
275
+ if silent_length >= 0:
276
+ result = np.append(result, np.zeros(silent_length))
277
+ result = np.append(result, seg_output)
278
+ else:
279
+ result = cross_fade(result, seg_output, current_length + silent_length)
280
+ current_length = current_length + silent_length + len(seg_output)
281
+ sf.write(cmd.output, result, output_sample_rate)
282
+
DDSP-SVC/main_diff.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import argparse
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import pyworld as pw
8
+ import parselmouth
9
+ import hashlib
10
+ from ast import literal_eval
11
+ from slicer import Slicer
12
+ from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
13
+ from ddsp.core import upsample
14
+ from diffusion.unit2mel import load_model_vocoder
15
+ from tqdm import tqdm
16
+
17
+ def check_args(ddsp_args, diff_args):
18
+ if ddsp_args.data.sampling_rate != diff_args.data.sampling_rate:
19
+ print("Unmatch data.sampling_rate!")
20
+ return False
21
+ if ddsp_args.data.block_size != diff_args.data.block_size:
22
+ print("Unmatch data.block_size!")
23
+ return False
24
+ if ddsp_args.data.encoder != diff_args.data.encoder:
25
+ print("Unmatch data.encoder!")
26
+ return False
27
+ return True
28
+
29
+ def parse_args(args=None, namespace=None):
30
+ """Parse command-line arguments."""
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument(
33
+ "-diff",
34
+ "--diff_ckpt",
35
+ type=str,
36
+ required=True,
37
+ help="path to the diffusion model checkpoint",
38
+ )
39
+ parser.add_argument(
40
+ "-ddsp",
41
+ "--ddsp_ckpt",
42
+ type=str,
43
+ required=False,
44
+ default="None",
45
+ help="path to the DDSP model checkpoint (for shallow diffusion)",
46
+ )
47
+ parser.add_argument(
48
+ "-d",
49
+ "--device",
50
+ type=str,
51
+ default=None,
52
+ required=False,
53
+ help="cpu or cuda, auto if not set")
54
+ parser.add_argument(
55
+ "-i",
56
+ "--input",
57
+ type=str,
58
+ required=True,
59
+ help="path to the input audio file",
60
+ )
61
+ parser.add_argument(
62
+ "-o",
63
+ "--output",
64
+ type=str,
65
+ required=True,
66
+ help="path to the output audio file",
67
+ )
68
+ parser.add_argument(
69
+ "-id",
70
+ "--spk_id",
71
+ type=str,
72
+ required=False,
73
+ default=1,
74
+ help="speaker id (for multi-speaker model) | default: 1",
75
+ )
76
+ parser.add_argument(
77
+ "-mix",
78
+ "--spk_mix_dict",
79
+ type=str,
80
+ required=False,
81
+ default="None",
82
+ help="mix-speaker dictionary (for multi-speaker model) | default: None",
83
+ )
84
+ parser.add_argument(
85
+ "-k",
86
+ "--key",
87
+ type=str,
88
+ required=False,
89
+ default=0,
90
+ help="key changed (number of semitones) | default: 0",
91
+ )
92
+ parser.add_argument(
93
+ "-f",
94
+ "--formant_shift_key",
95
+ type=str,
96
+ required=False,
97
+ default=0,
98
+ help="formant changed (number of semitones) , only for pitch-augmented model| default: 0",
99
+ )
100
+ parser.add_argument(
101
+ "-pe",
102
+ "--pitch_extractor",
103
+ type=str,
104
+ required=False,
105
+ default='crepe',
106
+ help="pitch extrator type: parselmouth, dio, harvest, crepe (default)",
107
+ )
108
+ parser.add_argument(
109
+ "-fmin",
110
+ "--f0_min",
111
+ type=str,
112
+ required=False,
113
+ default=50,
114
+ help="min f0 (Hz) | default: 50",
115
+ )
116
+ parser.add_argument(
117
+ "-fmax",
118
+ "--f0_max",
119
+ type=str,
120
+ required=False,
121
+ default=1100,
122
+ help="max f0 (Hz) | default: 1100",
123
+ )
124
+ parser.add_argument(
125
+ "-th",
126
+ "--threhold",
127
+ type=str,
128
+ required=False,
129
+ default=-60,
130
+ help="response threhold (dB) | default: -60",
131
+ )
132
+ parser.add_argument(
133
+ "-diffid",
134
+ "--diff_spk_id",
135
+ type=str,
136
+ required=False,
137
+ default='auto',
138
+ help="diffusion speaker id (for multi-speaker model) | default: auto",
139
+ )
140
+ parser.add_argument(
141
+ "-speedup",
142
+ "--speedup",
143
+ type=str,
144
+ required=False,
145
+ default='auto',
146
+ help="speed up | default: auto",
147
+ )
148
+ parser.add_argument(
149
+ "-method",
150
+ "--method",
151
+ type=str,
152
+ required=False,
153
+ default='auto',
154
+ help="pndm or dpm-solver | default: auto",
155
+ )
156
+ parser.add_argument(
157
+ "-kstep",
158
+ "--k_step",
159
+ type=str,
160
+ required=False,
161
+ default=None,
162
+ help="shallow diffusion steps | default: None",
163
+ )
164
+ return parser.parse_args(args=args, namespace=namespace)
165
+
166
+
167
+ def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000):
168
+ slicer = Slicer(
169
+ sr=sample_rate,
170
+ threshold=db_thresh,
171
+ min_length=min_len)
172
+ chunks = dict(slicer.slice(audio))
173
+ result = []
174
+ for k, v in chunks.items():
175
+ tag = v["split_time"].split(",")
176
+ if tag[0] != tag[1]:
177
+ start_frame = int(int(tag[0]) // hop_size)
178
+ end_frame = int(int(tag[1]) // hop_size)
179
+ if end_frame > start_frame:
180
+ result.append((
181
+ start_frame,
182
+ audio[int(start_frame * hop_size) : int(end_frame * hop_size)]))
183
+ return result
184
+
185
+
186
+ def cross_fade(a: np.ndarray, b: np.ndarray, idx: int):
187
+ result = np.zeros(idx + b.shape[0])
188
+ fade_len = a.shape[0] - idx
189
+ np.copyto(dst=result[:idx], src=a[:idx])
190
+ k = np.linspace(0, 1.0, num=fade_len, endpoint=True)
191
+ result[idx: a.shape[0]] = (1 - k) * a[idx:] + k * b[: fade_len]
192
+ np.copyto(dst=result[a.shape[0]:], src=b[fade_len:])
193
+ return result
194
+
195
+
196
+ if __name__ == '__main__':
197
+ # parse commands
198
+ cmd = parse_args()
199
+
200
+ #device = 'cpu'
201
+ device = cmd.device
202
+ if device is None:
203
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
204
+
205
+ # load diffusion model
206
+ model, vocoder, args = load_model_vocoder(cmd.diff_ckpt, device=device)
207
+
208
+ # load input
209
+ audio, sample_rate = librosa.load(cmd.input, sr=None)
210
+ if len(audio.shape) > 1:
211
+ audio = librosa.to_mono(audio)
212
+ hop_size = args.data.block_size * sample_rate / args.data.sampling_rate
213
+
214
+ # get MD5 hash from wav file
215
+ md5_hash = ""
216
+ with open(cmd.input, 'rb') as f:
217
+ data = f.read()
218
+ md5_hash = hashlib.md5(data).hexdigest()
219
+ print("MD5: " + md5_hash)
220
+
221
+ cache_dir_path = os.path.join(os.path.dirname(__file__), "cache")
222
+ cache_file_path = os.path.join(cache_dir_path, f"{cmd.pitch_extractor}_{hop_size}_{cmd.f0_min}_{cmd.f0_max}_{md5_hash}.npy")
223
+
224
+ is_cache_available = os.path.exists(cache_file_path)
225
+ if is_cache_available:
226
+ # f0 cache load
227
+ print('Loading pitch curves for input audio from cache directory...')
228
+ f0 = np.load(cache_file_path, allow_pickle=False)
229
+ else:
230
+ # extract f0
231
+ print('Pitch extractor type: ' + cmd.pitch_extractor)
232
+ pitch_extractor = F0_Extractor(
233
+ cmd.pitch_extractor,
234
+ sample_rate,
235
+ hop_size,
236
+ float(cmd.f0_min),
237
+ float(cmd.f0_max))
238
+ print('Extracting the pitch curve of the input audio...')
239
+ f0 = pitch_extractor.extract(audio, uv_interp = True, device = device)
240
+
241
+ # f0 cache save
242
+ os.makedirs(cache_dir_path, exist_ok=True)
243
+ np.save(cache_file_path, f0, allow_pickle=False)
244
+
245
+ f0 = torch.from_numpy(f0).float().to(device).unsqueeze(-1).unsqueeze(0)
246
+
247
+ # key change
248
+ f0 = f0 * 2 ** (float(cmd.key) / 12)
249
+
250
+ # formant change
251
+ formant_shift_key = torch.LongTensor(np.array([[float(cmd.formant_shift_key)]])).to(device)
252
+
253
+ # extract volume
254
+ print('Extracting the volume envelope of the input audio...')
255
+ volume_extractor = Volume_Extractor(hop_size)
256
+ volume = volume_extractor.extract(audio)
257
+ mask = (volume > 10 ** (float(cmd.threhold) / 20)).astype('float')
258
+ mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
259
+ mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
260
+ mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0)
261
+ mask = upsample(mask, args.data.block_size).squeeze(-1)
262
+ volume = torch.from_numpy(volume).float().to(device).unsqueeze(-1).unsqueeze(0)
263
+
264
+ # load units encoder
265
+ if args.data.encoder == 'cnhubertsoftfish':
266
+ cnhubertsoft_gate = args.data.cnhubertsoft_gate
267
+ else:
268
+ cnhubertsoft_gate = 10
269
+ units_encoder = Units_Encoder(
270
+ args.data.encoder,
271
+ args.data.encoder_ckpt,
272
+ args.data.encoder_sample_rate,
273
+ args.data.encoder_hop_size,
274
+ cnhubertsoft_gate=cnhubertsoft_gate,
275
+ device = device)
276
+
277
+ # speaker id or mix-speaker dictionary
278
+ spk_mix_dict = literal_eval(cmd.spk_mix_dict)
279
+ spk_id = torch.LongTensor(np.array([[int(cmd.spk_id)]])).to(device)
280
+ if cmd.diff_spk_id == 'auto':
281
+ diff_spk_id = spk_id
282
+ else:
283
+ diff_spk_id = torch.LongTensor(np.array([[int(cmd.diff_spk_id)]])).to(device)
284
+ if spk_mix_dict is not None:
285
+ print('Mix-speaker mode')
286
+ else:
287
+ print('DDSP Speaker ID: '+ str(int(cmd.spk_id)))
288
+ print('Diffusion Speaker ID: '+ str(cmd.diff_spk_id))
289
+
290
+ # speed up
291
+ if cmd.speedup == 'auto':
292
+ infer_speedup = args.infer.speedup
293
+ else:
294
+ infer_speedup = int(cmd.speedup)
295
+ if cmd.method == 'auto':
296
+ method = args.infer.method
297
+ else:
298
+ method = cmd.method
299
+ if infer_speedup > 1:
300
+ print('Sampling method: '+ method)
301
+ print('Speed up: '+ str(infer_speedup))
302
+ else:
303
+ print('Sampling method: DDPM')
304
+
305
+ ddsp = None
306
+ input_mel = None
307
+ k_step = None
308
+ if cmd.k_step is not None:
309
+ k_step = int(cmd.k_step)
310
+ print('Shallow diffusion step: ' + str(k_step))
311
+ if cmd.ddsp_ckpt != "None":
312
+ # load ddsp model
313
+ ddsp, ddsp_args = load_model(cmd.ddsp_ckpt, device=device)
314
+ if not check_args(ddsp_args, args):
315
+ print("Cannot use this DDSP model for shallow diffusion, gaussian diffusion will be used!")
316
+ ddsp = None
317
+ else:
318
+ print('DDSP model is not identified!')
319
+ print('Extracting the mel spectrum of the input audio for shallow diffusion...')
320
+ audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(device)
321
+ input_mel = vocoder.extract(audio_t, sample_rate)
322
+ input_mel = torch.cat((input_mel, input_mel[:,-1:,:]), 1)
323
+ else:
324
+ print('Shallow diffusion step is not identified, gaussian diffusion will be used!')
325
+
326
+ # forward and save the output
327
+ result = np.zeros(0)
328
+ current_length = 0
329
+ segments = split(audio, sample_rate, hop_size)
330
+ print('Cut the input audio into ' + str(len(segments)) + ' slices')
331
+ with torch.no_grad():
332
+ for segment in tqdm(segments):
333
+ start_frame = segment[0]
334
+ seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(device)
335
+ seg_units = units_encoder.encode(seg_input, sample_rate, hop_size)
336
+
337
+ seg_f0 = f0[:, start_frame : start_frame + seg_units.size(1), :]
338
+ seg_volume = volume[:, start_frame : start_frame + seg_units.size(1), :]
339
+ if ddsp is not None:
340
+ seg_ddsp_f0 = 2 ** (-float(cmd.formant_shift_key) / 12) * seg_f0
341
+ seg_ddsp_output, _ , (_, _) = ddsp(seg_units, seg_ddsp_f0, seg_volume, spk_id = spk_id, spk_mix_dict = spk_mix_dict)
342
+ seg_input_mel = vocoder.extract(seg_ddsp_output, args.data.sampling_rate, keyshift=float(cmd.formant_shift_key))
343
+ elif input_mel != None:
344
+ seg_input_mel = input_mel[:, start_frame : start_frame + seg_units.size(1), :]
345
+ else:
346
+ seg_input_mel = None
347
+
348
+ seg_mel = model(
349
+ seg_units,
350
+ seg_f0,
351
+ seg_volume,
352
+ spk_id = diff_spk_id,
353
+ spk_mix_dict = spk_mix_dict,
354
+ aug_shift = formant_shift_key,
355
+ gt_spec=seg_input_mel,
356
+ infer=True,
357
+ infer_speedup=infer_speedup,
358
+ method=method,
359
+ k_step=k_step)
360
+ seg_output = vocoder.infer(seg_mel, seg_f0)
361
+ seg_output *= mask[:, start_frame * args.data.block_size : (start_frame + seg_units.size(1)) * args.data.block_size]
362
+ seg_output = seg_output.squeeze().cpu().numpy()
363
+
364
+ silent_length = round(start_frame * args.data.block_size) - current_length
365
+ if silent_length >= 0:
366
+ result = np.append(result, np.zeros(silent_length))
367
+ result = np.append(result, seg_output)
368
+ else:
369
+ result = cross_fade(result, seg_output, current_length + silent_length)
370
+ current_length = current_length + silent_length + len(seg_output)
371
+ sf.write(cmd.output, result, args.data.sampling_rate)
372
+
DDSP-SVC/nsf_hifigan/env.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+
5
+ class AttrDict(dict):
6
+ def __init__(self, *args, **kwargs):
7
+ super(AttrDict, self).__init__(*args, **kwargs)
8
+ self.__dict__ = self
9
+
10
+
11
+ def build_env(config, config_name, path):
12
+ t_path = os.path.join(path, config_name)
13
+ if config != t_path:
14
+ os.makedirs(path, exist_ok=True)
15
+ shutil.copyfile(config, os.path.join(path, config_name))
DDSP-SVC/nsf_hifigan/models.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from .env import AttrDict
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
10
+ from .utils import init_weights, get_padding
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
+ def load_model(model_path, device='cuda'):
16
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
17
+ with open(config_file) as f:
18
+ data = f.read()
19
+
20
+ json_config = json.loads(data)
21
+ h = AttrDict(json_config)
22
+
23
+ generator = Generator(h).to(device)
24
+
25
+ cp_dict = torch.load(model_path, map_location=device)
26
+ generator.load_state_dict(cp_dict['generator'])
27
+ generator.eval()
28
+ generator.remove_weight_norm()
29
+ del cp_dict
30
+ return generator, h
31
+
32
+
33
+ class ResBlock1(torch.nn.Module):
34
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
35
+ super(ResBlock1, self).__init__()
36
+ self.h = h
37
+ self.convs1 = nn.ModuleList([
38
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
39
+ padding=get_padding(kernel_size, dilation[0]))),
40
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
41
+ padding=get_padding(kernel_size, dilation[1]))),
42
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
43
+ padding=get_padding(kernel_size, dilation[2])))
44
+ ])
45
+ self.convs1.apply(init_weights)
46
+
47
+ self.convs2 = nn.ModuleList([
48
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
49
+ padding=get_padding(kernel_size, 1))),
50
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
51
+ padding=get_padding(kernel_size, 1))),
52
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
53
+ padding=get_padding(kernel_size, 1)))
54
+ ])
55
+ self.convs2.apply(init_weights)
56
+
57
+ def forward(self, x):
58
+ for c1, c2 in zip(self.convs1, self.convs2):
59
+ xt = F.leaky_relu(x, LRELU_SLOPE)
60
+ xt = c1(xt)
61
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
62
+ xt = c2(xt)
63
+ x = xt + x
64
+ return x
65
+
66
+ def remove_weight_norm(self):
67
+ for l in self.convs1:
68
+ remove_weight_norm(l)
69
+ for l in self.convs2:
70
+ remove_weight_norm(l)
71
+
72
+
73
+ class ResBlock2(torch.nn.Module):
74
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
75
+ super(ResBlock2, self).__init__()
76
+ self.h = h
77
+ self.convs = nn.ModuleList([
78
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
79
+ padding=get_padding(kernel_size, dilation[0]))),
80
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
81
+ padding=get_padding(kernel_size, dilation[1])))
82
+ ])
83
+ self.convs.apply(init_weights)
84
+
85
+ def forward(self, x):
86
+ for c in self.convs:
87
+ xt = F.leaky_relu(x, LRELU_SLOPE)
88
+ xt = c(xt)
89
+ x = xt + x
90
+ return x
91
+
92
+ def remove_weight_norm(self):
93
+ for l in self.convs:
94
+ remove_weight_norm(l)
95
+
96
+
97
+ class SineGen(torch.nn.Module):
98
+ """ Definition of sine generator
99
+ SineGen(samp_rate, harmonic_num = 0,
100
+ sine_amp = 0.1, noise_std = 0.003,
101
+ voiced_threshold = 0,
102
+ flag_for_pulse=False)
103
+ samp_rate: sampling rate in Hz
104
+ harmonic_num: number of harmonic overtones (default 0)
105
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
106
+ noise_std: std of Gaussian noise (default 0.003)
107
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
108
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
109
+ Note: when flag_for_pulse is True, the first time step of a voiced
110
+ segment is always sin(np.pi) or cos(0)
111
+ """
112
+
113
+ def __init__(self, samp_rate, harmonic_num=0,
114
+ sine_amp=0.1, noise_std=0.003,
115
+ voiced_threshold=0):
116
+ super(SineGen, self).__init__()
117
+ self.sine_amp = sine_amp
118
+ self.noise_std = noise_std
119
+ self.harmonic_num = harmonic_num
120
+ self.dim = self.harmonic_num + 1
121
+ self.sampling_rate = samp_rate
122
+ self.voiced_threshold = voiced_threshold
123
+
124
+ def _f02uv(self, f0):
125
+ # generate uv signal
126
+ uv = torch.ones_like(f0)
127
+ uv = uv * (f0 > self.voiced_threshold)
128
+ return uv
129
+
130
+ @torch.no_grad()
131
+ def forward(self, f0, upp):
132
+ """ sine_tensor, uv = forward(f0)
133
+ input F0: tensor(batchsize=1, length, dim=1)
134
+ f0 for unvoiced steps should be 0
135
+ output sine_tensor: tensor(batchsize=1, length, dim)
136
+ output uv: tensor(batchsize=1, length, 1)
137
+ """
138
+ f0 = f0.unsqueeze(-1)
139
+ fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
140
+ rad_values = (fn / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
141
+ rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
142
+ rand_ini[:, 0] = 0
143
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
144
+ is_half = rad_values.dtype is not torch.float32
145
+ tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化
146
+ if is_half:
147
+ tmp_over_one = tmp_over_one.half()
148
+ else:
149
+ tmp_over_one = tmp_over_one.float()
150
+ tmp_over_one *= upp
151
+ tmp_over_one = F.interpolate(
152
+ tmp_over_one.transpose(2, 1), scale_factor=upp,
153
+ mode='linear', align_corners=True
154
+ ).transpose(2, 1)
155
+ rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
156
+ tmp_over_one %= 1
157
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
158
+ cumsum_shift = torch.zeros_like(rad_values)
159
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
160
+ rad_values = rad_values.double()
161
+ cumsum_shift = cumsum_shift.double()
162
+ sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
163
+ if is_half:
164
+ sine_waves = sine_waves.half()
165
+ else:
166
+ sine_waves = sine_waves.float()
167
+ sine_waves = sine_waves * self.sine_amp
168
+ return sine_waves
169
+
170
+
171
+ class SourceModuleHnNSF(torch.nn.Module):
172
+ """ SourceModule for hn-nsf
173
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
174
+ add_noise_std=0.003, voiced_threshod=0)
175
+ sampling_rate: sampling_rate in Hz
176
+ harmonic_num: number of harmonic above F0 (default: 0)
177
+ sine_amp: amplitude of sine source signal (default: 0.1)
178
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
179
+ note that amplitude of noise in unvoiced is decided
180
+ by sine_amp
181
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
182
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
183
+ F0_sampled (batchsize, length, 1)
184
+ Sine_source (batchsize, length, 1)
185
+ noise_source (batchsize, length 1)
186
+ uv (batchsize, length, 1)
187
+ """
188
+
189
+ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
190
+ add_noise_std=0.003, voiced_threshod=0):
191
+ super(SourceModuleHnNSF, self).__init__()
192
+
193
+ self.sine_amp = sine_amp
194
+ self.noise_std = add_noise_std
195
+
196
+ # to produce sine waveforms
197
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
198
+ sine_amp, add_noise_std, voiced_threshod)
199
+
200
+ # to merge source harmonics into a single excitation
201
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
202
+ self.l_tanh = torch.nn.Tanh()
203
+
204
+ def forward(self, x, upp):
205
+ sine_wavs = self.l_sin_gen(x, upp)
206
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
207
+ return sine_merge
208
+
209
+
210
+ class Generator(torch.nn.Module):
211
+ def __init__(self, h):
212
+ super(Generator, self).__init__()
213
+ self.h = h
214
+ self.num_kernels = len(h.resblock_kernel_sizes)
215
+ self.num_upsamples = len(h.upsample_rates)
216
+ self.m_source = SourceModuleHnNSF(
217
+ sampling_rate=h.sampling_rate,
218
+ harmonic_num=8
219
+ )
220
+ self.noise_convs = nn.ModuleList()
221
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
222
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
223
+
224
+ self.ups = nn.ModuleList()
225
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
226
+ c_cur = h.upsample_initial_channel // (2 ** (i + 1))
227
+ self.ups.append(weight_norm(
228
+ ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
229
+ k, u, padding=(k - u) // 2)))
230
+ if i + 1 < len(h.upsample_rates): #
231
+ stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
232
+ self.noise_convs.append(Conv1d(
233
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
234
+ else:
235
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
236
+ self.resblocks = nn.ModuleList()
237
+ ch = h.upsample_initial_channel
238
+ for i in range(len(self.ups)):
239
+ ch //= 2
240
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
241
+ self.resblocks.append(resblock(h, ch, k, d))
242
+
243
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
244
+ self.ups.apply(init_weights)
245
+ self.conv_post.apply(init_weights)
246
+ self.upp = int(np.prod(h.upsample_rates))
247
+
248
+ def forward(self, x, f0):
249
+ har_source = self.m_source(f0, self.upp).transpose(1, 2)
250
+ x = self.conv_pre(x)
251
+ for i in range(self.num_upsamples):
252
+ x = F.leaky_relu(x, LRELU_SLOPE)
253
+ x = self.ups[i](x)
254
+ x_source = self.noise_convs[i](har_source)
255
+ x = x + x_source
256
+ xs = None
257
+ for j in range(self.num_kernels):
258
+ if xs is None:
259
+ xs = self.resblocks[i * self.num_kernels + j](x)
260
+ else:
261
+ xs += self.resblocks[i * self.num_kernels + j](x)
262
+ x = xs / self.num_kernels
263
+ x = F.leaky_relu(x)
264
+ x = self.conv_post(x)
265
+ x = torch.tanh(x)
266
+
267
+ return x
268
+
269
+ def remove_weight_norm(self):
270
+ print('Removing weight norm...')
271
+ for l in self.ups:
272
+ remove_weight_norm(l)
273
+ for l in self.resblocks:
274
+ l.remove_weight_norm()
275
+ remove_weight_norm(self.conv_pre)
276
+ remove_weight_norm(self.conv_post)
277
+
278
+
279
+ class DiscriminatorP(torch.nn.Module):
280
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
281
+ super(DiscriminatorP, self).__init__()
282
+ self.period = period
283
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
284
+ self.convs = nn.ModuleList([
285
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
286
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
287
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
288
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
289
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
290
+ ])
291
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
292
+
293
+ def forward(self, x):
294
+ fmap = []
295
+
296
+ # 1d to 2d
297
+ b, c, t = x.shape
298
+ if t % self.period != 0: # pad first
299
+ n_pad = self.period - (t % self.period)
300
+ x = F.pad(x, (0, n_pad), "reflect")
301
+ t = t + n_pad
302
+ x = x.view(b, c, t // self.period, self.period)
303
+
304
+ for l in self.convs:
305
+ x = l(x)
306
+ x = F.leaky_relu(x, LRELU_SLOPE)
307
+ fmap.append(x)
308
+ x = self.conv_post(x)
309
+ fmap.append(x)
310
+ x = torch.flatten(x, 1, -1)
311
+
312
+ return x, fmap
313
+
314
+
315
+ class MultiPeriodDiscriminator(torch.nn.Module):
316
+ def __init__(self, periods=None):
317
+ super(MultiPeriodDiscriminator, self).__init__()
318
+ self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
319
+ self.discriminators = nn.ModuleList()
320
+ for period in self.periods:
321
+ self.discriminators.append(DiscriminatorP(period))
322
+
323
+ def forward(self, y, y_hat):
324
+ y_d_rs = []
325
+ y_d_gs = []
326
+ fmap_rs = []
327
+ fmap_gs = []
328
+ for i, d in enumerate(self.discriminators):
329
+ y_d_r, fmap_r = d(y)
330
+ y_d_g, fmap_g = d(y_hat)
331
+ y_d_rs.append(y_d_r)
332
+ fmap_rs.append(fmap_r)
333
+ y_d_gs.append(y_d_g)
334
+ fmap_gs.append(fmap_g)
335
+
336
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
337
+
338
+
339
+ class DiscriminatorS(torch.nn.Module):
340
+ def __init__(self, use_spectral_norm=False):
341
+ super(DiscriminatorS, self).__init__()
342
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
343
+ self.convs = nn.ModuleList([
344
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
345
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
346
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
347
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
348
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
349
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
350
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
351
+ ])
352
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
353
+
354
+ def forward(self, x):
355
+ fmap = []
356
+ for l in self.convs:
357
+ x = l(x)
358
+ x = F.leaky_relu(x, LRELU_SLOPE)
359
+ fmap.append(x)
360
+ x = self.conv_post(x)
361
+ fmap.append(x)
362
+ x = torch.flatten(x, 1, -1)
363
+
364
+ return x, fmap
365
+
366
+
367
+ class MultiScaleDiscriminator(torch.nn.Module):
368
+ def __init__(self):
369
+ super(MultiScaleDiscriminator, self).__init__()
370
+ self.discriminators = nn.ModuleList([
371
+ DiscriminatorS(use_spectral_norm=True),
372
+ DiscriminatorS(),
373
+ DiscriminatorS(),
374
+ ])
375
+ self.meanpools = nn.ModuleList([
376
+ AvgPool1d(4, 2, padding=2),
377
+ AvgPool1d(4, 2, padding=2)
378
+ ])
379
+
380
+ def forward(self, y, y_hat):
381
+ y_d_rs = []
382
+ y_d_gs = []
383
+ fmap_rs = []
384
+ fmap_gs = []
385
+ for i, d in enumerate(self.discriminators):
386
+ if i != 0:
387
+ y = self.meanpools[i - 1](y)
388
+ y_hat = self.meanpools[i - 1](y_hat)
389
+ y_d_r, fmap_r = d(y)
390
+ y_d_g, fmap_g = d(y_hat)
391
+ y_d_rs.append(y_d_r)
392
+ fmap_rs.append(fmap_r)
393
+ y_d_gs.append(y_d_g)
394
+ fmap_gs.append(fmap_g)
395
+
396
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
397
+
398
+
399
+ def feature_loss(fmap_r, fmap_g):
400
+ loss = 0
401
+ for dr, dg in zip(fmap_r, fmap_g):
402
+ for rl, gl in zip(dr, dg):
403
+ loss += torch.mean(torch.abs(rl - gl))
404
+
405
+ return loss * 2
406
+
407
+
408
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
409
+ loss = 0
410
+ r_losses = []
411
+ g_losses = []
412
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
413
+ r_loss = torch.mean((1 - dr) ** 2)
414
+ g_loss = torch.mean(dg ** 2)
415
+ loss += (r_loss + g_loss)
416
+ r_losses.append(r_loss.item())
417
+ g_losses.append(g_loss.item())
418
+
419
+ return loss, r_losses, g_losses
420
+
421
+
422
+ def generator_loss(disc_outputs):
423
+ loss = 0
424
+ gen_losses = []
425
+ for dg in disc_outputs:
426
+ l = torch.mean((1 - dg) ** 2)
427
+ gen_losses.append(l)
428
+ loss += l
429
+
430
+ return loss, gen_losses
DDSP-SVC/nsf_hifigan/nvSTFT.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
4
+ import random
5
+ import torch
6
+ import torch.utils.data
7
+ import numpy as np
8
+ import librosa
9
+ from librosa.util import normalize
10
+ from librosa.filters import mel as librosa_mel_fn
11
+ from scipy.io.wavfile import read
12
+ import soundfile as sf
13
+ import torch.nn.functional as F
14
+
15
+ def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
16
+ sampling_rate = None
17
+ try:
18
+ data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile.
19
+ except Exception as ex:
20
+ print(f"'{full_path}' failed to load.\nException:")
21
+ print(ex)
22
+ if return_empty_on_exception:
23
+ return [], sampling_rate or target_sr or 48000
24
+ else:
25
+ raise Exception(ex)
26
+
27
+ if len(data.shape) > 1:
28
+ data = data[:, 0]
29
+ assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
30
+
31
+ if np.issubdtype(data.dtype, np.integer): # if audio data is type int
32
+ max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
33
+ else: # if audio data is type fp32
34
+ max_mag = max(np.amax(data), -np.amin(data))
35
+ max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
36
+
37
+ data = torch.FloatTensor(data.astype(np.float32))/max_mag
38
+
39
+ if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
40
+ return [], sampling_rate or target_sr or 48000
41
+ if target_sr is not None and sampling_rate != target_sr:
42
+ data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
43
+ sampling_rate = target_sr
44
+
45
+ return data, sampling_rate
46
+
47
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
48
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
49
+
50
+ def dynamic_range_decompression(x, C=1):
51
+ return np.exp(x) / C
52
+
53
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
54
+ return torch.log(torch.clamp(x, min=clip_val) * C)
55
+
56
+ def dynamic_range_decompression_torch(x, C=1):
57
+ return torch.exp(x) / C
58
+
59
+ class STFT():
60
+ def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
61
+ self.target_sr = sr
62
+
63
+ self.n_mels = n_mels
64
+ self.n_fft = n_fft
65
+ self.win_size = win_size
66
+ self.hop_length = hop_length
67
+ self.fmin = fmin
68
+ self.fmax = fmax
69
+ self.clip_val = clip_val
70
+ self.mel_basis = {}
71
+ self.hann_window = {}
72
+
73
+ def get_mel(self, y, keyshift=0, speed=1, center=False):
74
+ sampling_rate = self.target_sr
75
+ n_mels = self.n_mels
76
+ n_fft = self.n_fft
77
+ win_size = self.win_size
78
+ hop_length = self.hop_length
79
+ fmin = self.fmin
80
+ fmax = self.fmax
81
+ clip_val = self.clip_val
82
+
83
+ factor = 2 ** (keyshift / 12)
84
+ n_fft_new = int(np.round(n_fft * factor))
85
+ win_size_new = int(np.round(win_size * factor))
86
+ hop_length_new = int(np.round(hop_length * speed))
87
+
88
+ if torch.min(y) < -1.:
89
+ print('min value is ', torch.min(y))
90
+ if torch.max(y) > 1.:
91
+ print('max value is ', torch.max(y))
92
+
93
+ mel_basis_key = str(fmax)+'_'+str(y.device)
94
+ if mel_basis_key not in self.mel_basis:
95
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
96
+ self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
97
+
98
+ keyshift_key = str(keyshift)+'_'+str(y.device)
99
+ if keyshift_key not in self.hann_window:
100
+ self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
101
+
102
+ pad_left = (win_size_new - hop_length_new) //2
103
+ pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
104
+ if pad_right < y.size(-1):
105
+ mode = 'reflect'
106
+ else:
107
+ mode = 'constant'
108
+ y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
109
+ y = y.squeeze(1)
110
+
111
+ spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
112
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
113
+ spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
114
+ if keyshift != 0:
115
+ size = n_fft // 2 + 1
116
+ resize = spec.size(1)
117
+ if resize < size:
118
+ spec = F.pad(spec, (0, 0, 0, size-resize))
119
+ spec = spec[:, :size, :] * win_size / win_size_new
120
+ spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
121
+ spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
122
+ return spec
123
+
124
+ def __call__(self, audiopath):
125
+ audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
126
+ spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
127
+ return spect
128
+
129
+ stft = STFT()
DDSP-SVC/nsf_hifigan/utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import matplotlib
4
+ import torch
5
+ from torch.nn.utils import weight_norm
6
+ matplotlib.use("Agg")
7
+ import matplotlib.pylab as plt
8
+
9
+
10
+ def plot_spectrogram(spectrogram):
11
+ fig, ax = plt.subplots(figsize=(10, 2))
12
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13
+ interpolation='none')
14
+ plt.colorbar(im, ax=ax)
15
+
16
+ fig.canvas.draw()
17
+ plt.close()
18
+
19
+ return fig
20
+
21
+
22
+ def init_weights(m, mean=0.0, std=0.01):
23
+ classname = m.__class__.__name__
24
+ if classname.find("Conv") != -1:
25
+ m.weight.data.normal_(mean, std)
26
+
27
+
28
+ def apply_weight_norm(m):
29
+ classname = m.__class__.__name__
30
+ if classname.find("Conv") != -1:
31
+ weight_norm(m)
32
+
33
+
34
+ def get_padding(kernel_size, dilation=1):
35
+ return int((kernel_size*dilation - dilation)/2)
36
+
37
+
38
+ def load_checkpoint(filepath, device):
39
+ assert os.path.isfile(filepath)
40
+ print("Loading '{}'".format(filepath))
41
+ checkpoint_dict = torch.load(filepath, map_location=device)
42
+ print("Complete.")
43
+ return checkpoint_dict
44
+
45
+
46
+ def save_checkpoint(filepath, obj):
47
+ print("Saving checkpoint to {}".format(filepath))
48
+ torch.save(obj, filepath)
49
+ print("Complete.")
50
+
51
+
52
+ def del_old_checkpoints(cp_dir, prefix, n_models=2):
53
+ pattern = os.path.join(cp_dir, prefix + '????????')
54
+ cp_list = glob.glob(pattern) # get checkpoint paths
55
+ cp_list = sorted(cp_list)# sort by iter
56
+ if len(cp_list) > n_models: # if more than n_models models are found
57
+ for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models
58
+ open(cp, 'w').close()# empty file contents
59
+ os.unlink(cp)# delete file (move to trash when using Colab)
60
+
61
+
62
+ def scan_checkpoint(cp_dir, prefix):
63
+ pattern = os.path.join(cp_dir, prefix + '????????')
64
+ cp_list = glob.glob(pattern)
65
+ if len(cp_list) == 0:
66
+ return None
67
+ return sorted(cp_list)[-1]
68
+
DDSP-SVC/preprocess.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ import librosa
5
+ import torch
6
+ import pyworld as pw
7
+ import parselmouth
8
+ import argparse
9
+ import shutil
10
+ from logger import utils
11
+ from tqdm import tqdm
12
+ from ddsp.vocoder import F0_Extractor, Volume_Extractor, Units_Encoder
13
+ from diffusion.vocoder import Vocoder
14
+ from logger.utils import traverse_dir
15
+ import concurrent.futures
16
+
17
+ def parse_args(args=None, namespace=None):
18
+ """Parse command-line arguments."""
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument(
21
+ "-c",
22
+ "--config",
23
+ type=str,
24
+ required=True,
25
+ help="path to the config file")
26
+ parser.add_argument(
27
+ "-d",
28
+ "--device",
29
+ type=str,
30
+ default=None,
31
+ required=False,
32
+ help="cpu or cuda, auto if not set")
33
+ return parser.parse_args(args=args, namespace=namespace)
34
+
35
+ def preprocess(path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device = 'cuda', use_pitch_aug = False):
36
+
37
+ path_srcdir = os.path.join(path, 'audio')
38
+ path_unitsdir = os.path.join(path, 'units')
39
+ path_f0dir = os.path.join(path, 'f0')
40
+ path_volumedir = os.path.join(path, 'volume')
41
+ path_augvoldir = os.path.join(path, 'aug_vol')
42
+ path_meldir = os.path.join(path, 'mel')
43
+ path_augmeldir = os.path.join(path, 'aug_mel')
44
+ path_skipdir = os.path.join(path, 'skip')
45
+
46
+ # list files
47
+ filelist = traverse_dir(
48
+ path_srcdir,
49
+ extension='wav',
50
+ is_pure=True,
51
+ is_sort=True,
52
+ is_ext=True)
53
+
54
+ # pitch augmentation dictionary
55
+ pitch_aug_dict = {}
56
+
57
+ # run
58
+ def process(file):
59
+ ext = file.split('.')[-1]
60
+ binfile = file[:-(len(ext)+1)]+'.npy'
61
+ path_srcfile = os.path.join(path_srcdir, file)
62
+ path_unitsfile = os.path.join(path_unitsdir, binfile)
63
+ path_f0file = os.path.join(path_f0dir, binfile)
64
+ path_volumefile = os.path.join(path_volumedir, binfile)
65
+ path_augvolfile = os.path.join(path_augvoldir, binfile)
66
+ path_melfile = os.path.join(path_meldir, binfile)
67
+ path_augmelfile = os.path.join(path_augmeldir, binfile)
68
+ path_skipfile = os.path.join(path_skipdir, file)
69
+
70
+ # load audio
71
+ audio, _ = librosa.load(path_srcfile, sr=sample_rate)
72
+ if len(audio.shape) > 1:
73
+ audio = librosa.to_mono(audio)
74
+ audio_t = torch.from_numpy(audio).float().to(device)
75
+ audio_t = audio_t.unsqueeze(0)
76
+
77
+ # extract volume
78
+ volume = volume_extractor.extract(audio)
79
+
80
+ # extract mel and volume augmentaion
81
+ if mel_extractor is not None:
82
+ mel_t = mel_extractor.extract(audio_t, sample_rate)
83
+ mel = mel_t.squeeze().to('cpu').numpy()
84
+
85
+ max_amp = float(torch.max(torch.abs(audio_t))) + 1e-5
86
+ max_shift = min(1, np.log10(1/max_amp))
87
+ log10_vol_shift = random.uniform(-1, max_shift)
88
+ if use_pitch_aug:
89
+ keyshift = random.uniform(-5, 5)
90
+ else:
91
+ keyshift = 0
92
+
93
+ aug_mel_t = mel_extractor.extract(audio_t * (10 ** log10_vol_shift), sample_rate, keyshift = keyshift)
94
+ aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
95
+ aug_vol = volume_extractor.extract(audio * (10 ** log10_vol_shift))
96
+
97
+ # units encode
98
+ units_t = units_encoder.encode(audio_t, sample_rate, hop_size)
99
+ units = units_t.squeeze().to('cpu').numpy()
100
+
101
+ # extract f0
102
+ f0 = f0_extractor.extract(audio, uv_interp = False)
103
+
104
+ uv = f0 == 0
105
+ if len(f0[~uv]) > 0:
106
+ # interpolate the unvoiced f0
107
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
108
+
109
+ # save npy
110
+ os.makedirs(os.path.dirname(path_unitsfile), exist_ok=True)
111
+ np.save(path_unitsfile, units)
112
+ os.makedirs(os.path.dirname(path_f0file), exist_ok=True)
113
+ np.save(path_f0file, f0)
114
+ os.makedirs(os.path.dirname(path_volumefile), exist_ok=True)
115
+ np.save(path_volumefile, volume)
116
+ if mel_extractor is not None:
117
+ pitch_aug_dict[file[:-(len(ext)+1)]] = keyshift
118
+ os.makedirs(os.path.dirname(path_melfile), exist_ok=True)
119
+ np.save(path_melfile, mel)
120
+ os.makedirs(os.path.dirname(path_augmelfile), exist_ok=True)
121
+ np.save(path_augmelfile, aug_mel)
122
+ os.makedirs(os.path.dirname(path_augvolfile), exist_ok=True)
123
+ np.save(path_augvolfile, aug_vol)
124
+ else:
125
+ print('\n[Error] F0 extraction failed: ' + path_srcfile)
126
+ os.makedirs(os.path.dirname(path_skipfile), exist_ok=True)
127
+ shutil.move(path_srcfile, os.path.dirname(path_skipfile))
128
+ print('This file has been moved to ' + path_skipfile)
129
+ print('Preprocess the audio clips in :', path_srcdir)
130
+
131
+ # single process
132
+ for file in tqdm(filelist, total=len(filelist)):
133
+ process(file)
134
+
135
+ if mel_extractor is not None:
136
+ path_pitchaugdict = os.path.join(path, 'pitch_aug_dict.npy')
137
+ np.save(path_pitchaugdict, pitch_aug_dict)
138
+ # multi-process (have bugs)
139
+ '''
140
+ with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
141
+ list(tqdm(executor.map(process, filelist), total=len(filelist)))
142
+ '''
143
+
144
+ if __name__ == '__main__':
145
+ # parse commands
146
+ cmd = parse_args()
147
+
148
+ device = cmd.device
149
+ if device is None:
150
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
151
+
152
+ # load config
153
+ args = utils.load_config(cmd.config)
154
+ sample_rate = args.data.sampling_rate
155
+ hop_size = args.data.block_size
156
+
157
+ # initialize f0 extractor
158
+ f0_extractor = F0_Extractor(
159
+ args.data.f0_extractor,
160
+ args.data.sampling_rate,
161
+ args.data.block_size,
162
+ args.data.f0_min,
163
+ args.data.f0_max)
164
+
165
+ # initialize volume extractor
166
+ volume_extractor = Volume_Extractor(args.data.block_size)
167
+
168
+ # initialize mel extractor
169
+ mel_extractor = None
170
+ use_pitch_aug = False
171
+ if args.model.type == 'Diffusion':
172
+ mel_extractor = Vocoder(args.vocoder.type, args.vocoder.ckpt, device = device)
173
+ if mel_extractor.vocoder_sample_rate != sample_rate or mel_extractor.vocoder_hop_size != hop_size:
174
+ mel_extractor = None
175
+ print('Unmatch vocoder parameters, mel extraction is ignored!')
176
+ elif args.model.use_pitch_aug:
177
+ use_pitch_aug = True
178
+
179
+ # initialize units encoder
180
+ if args.data.encoder == 'cnhubertsoftfish':
181
+ cnhubertsoft_gate = args.data.cnhubertsoft_gate
182
+ else:
183
+ cnhubertsoft_gate = 10
184
+ units_encoder = Units_Encoder(
185
+ args.data.encoder,
186
+ args.data.encoder_ckpt,
187
+ args.data.encoder_sample_rate,
188
+ args.data.encoder_hop_size,
189
+ cnhubertsoft_gate=cnhubertsoft_gate,
190
+ device = device)
191
+
192
+ # preprocess training set
193
+ preprocess(args.data.train_path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device = device, use_pitch_aug = use_pitch_aug)
194
+
195
+ # preprocess validation set
196
+ preprocess(args.data.valid_path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device = device, use_pitch_aug = False)
197
+
DDSP-SVC/pretrain/hubert/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
DDSP-SVC/pretrain/nsf_hifigan/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
DDSP-SVC/requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops
2
+ fairseq
3
+ flask
4
+ flask_cors
5
+ gin
6
+ gin_config
7
+ librosa
8
+ local_attention
9
+ matplotlib
10
+ numpy
11
+ praat-parselmouth
12
+ pyworld
13
+ PyYAML
14
+ resampy
15
+ scikit_learn
16
+ scipy
17
+ SoundFile
18
+ tensorboard
19
+ torchcrepe
20
+ tqdm
21
+ transformers
22
+ wave
23
+ pysimplegui
24
+ sounddevice
25
+ gradio
DDSP-SVC/slicer.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch
3
+ import torchaudio
4
+
5
+
6
+ class Slicer:
7
+ def __init__(self,
8
+ sr: int,
9
+ threshold: float = -40.,
10
+ min_length: int = 5000,
11
+ min_interval: int = 300,
12
+ hop_size: int = 20,
13
+ max_sil_kept: int = 5000):
14
+ if not min_length >= min_interval >= hop_size:
15
+ raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
16
+ if not max_sil_kept >= hop_size:
17
+ raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
18
+ min_interval = sr * min_interval / 1000
19
+ self.threshold = 10 ** (threshold / 20.)
20
+ self.hop_size = round(sr * hop_size / 1000)
21
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
22
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
23
+ self.min_interval = round(min_interval / self.hop_size)
24
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
25
+
26
+ def _apply_slice(self, waveform, begin, end):
27
+ if len(waveform.shape) > 1:
28
+ return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
29
+ else:
30
+ return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
31
+
32
+ # @timeit
33
+ def slice(self, waveform):
34
+ if len(waveform.shape) > 1:
35
+ samples = librosa.to_mono(waveform)
36
+ else:
37
+ samples = waveform
38
+ if samples.shape[0] <= self.min_length:
39
+ return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
40
+ rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
41
+ sil_tags = []
42
+ silence_start = None
43
+ clip_start = 0
44
+ for i, rms in enumerate(rms_list):
45
+ # Keep looping while frame is silent.
46
+ if rms < self.threshold:
47
+ # Record start of silent frames.
48
+ if silence_start is None:
49
+ silence_start = i
50
+ continue
51
+ # Keep looping while frame is not silent and silence start has not been recorded.
52
+ if silence_start is None:
53
+ continue
54
+ # Clear recorded silence start if interval is not enough or clip is too short
55
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
56
+ need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
57
+ if not is_leading_silence and not need_slice_middle:
58
+ silence_start = None
59
+ continue
60
+ # Need slicing. Record the range of silent frames to be removed.
61
+ if i - silence_start <= self.max_sil_kept:
62
+ pos = rms_list[silence_start: i + 1].argmin() + silence_start
63
+ if silence_start == 0:
64
+ sil_tags.append((0, pos))
65
+ else:
66
+ sil_tags.append((pos, pos))
67
+ clip_start = pos
68
+ elif i - silence_start <= self.max_sil_kept * 2:
69
+ pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
70
+ pos += i - self.max_sil_kept
71
+ pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
72
+ pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
73
+ if silence_start == 0:
74
+ sil_tags.append((0, pos_r))
75
+ clip_start = pos_r
76
+ else:
77
+ sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
78
+ clip_start = max(pos_r, pos)
79
+ else:
80
+ pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
81
+ pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
82
+ if silence_start == 0:
83
+ sil_tags.append((0, pos_r))
84
+ else:
85
+ sil_tags.append((pos_l, pos_r))
86
+ clip_start = pos_r
87
+ silence_start = None
88
+ # Deal with trailing silence.
89
+ total_frames = rms_list.shape[0]
90
+ if silence_start is not None and total_frames - silence_start >= self.min_interval:
91
+ silence_end = min(total_frames, silence_start + self.max_sil_kept)
92
+ pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
93
+ sil_tags.append((pos, total_frames + 1))
94
+ # Apply and return slices.
95
+ if len(sil_tags) == 0:
96
+ return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
97
+ else:
98
+ chunks = []
99
+ # 第一段静音并非从头开始,补上有声片段
100
+ if sil_tags[0][0]:
101
+ chunks.append(
102
+ {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
103
+ for i in range(0, len(sil_tags)):
104
+ # 标识有声片段(跳过第一段)
105
+ if i:
106
+ chunks.append({"slice": False,
107
+ "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
108
+ # 标识所有静音片段
109
+ chunks.append({"slice": True,
110
+ "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
111
+ # 最后一段静音并非结尾,补上结尾片段
112
+ if sil_tags[-1][1] * self.hop_size < len(waveform):
113
+ chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
114
+ chunk_dict = {}
115
+ for i in range(len(chunks)):
116
+ chunk_dict[str(i)] = chunks[i]
117
+ return chunk_dict
118
+
119
+
120
+ def cut(audio_path, db_thresh=-30, min_len=5000, flask_mode=False, flask_sr=None):
121
+ if not flask_mode:
122
+ audio, sr = librosa.load(audio_path, sr=None)
123
+ else:
124
+ audio = audio_path
125
+ sr = flask_sr
126
+ slicer = Slicer(
127
+ sr=sr,
128
+ threshold=db_thresh,
129
+ min_length=min_len
130
+ )
131
+ chunks = slicer.slice(audio)
132
+ return chunks
133
+
134
+
135
+ def chunks2audio(audio_path, chunks):
136
+ chunks = dict(chunks)
137
+ audio, sr = torchaudio.load(audio_path)
138
+ if len(audio.shape) == 2 and audio.shape[1] >= 2:
139
+ audio = torch.mean(audio, dim=0).unsqueeze(0)
140
+ audio = audio.cpu().numpy()[0]
141
+ result = []
142
+ for k, v in chunks.items():
143
+ tag = v["split_time"].split(",")
144
+ if tag[0] != tag[1]:
145
+ result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
146
+ return result, sr
DDSP-SVC/solver.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+
6
+ from logger.saver import Saver
7
+ from logger import utils
8
+
9
+ def test(args, model, loss_func, loader_test, saver):
10
+ print(' [*] testing...')
11
+ model.eval()
12
+
13
+ # losses
14
+ test_loss = 0.
15
+ test_loss_rss = 0.
16
+ test_loss_uv = 0.
17
+
18
+ # intialization
19
+ num_batches = len(loader_test)
20
+ rtf_all = []
21
+
22
+ # run
23
+ with torch.no_grad():
24
+ for bidx, data in enumerate(loader_test):
25
+ fn = data['name'][0]
26
+ print('--------')
27
+ print('{}/{} - {}'.format(bidx, num_batches, fn))
28
+
29
+ # unpack data
30
+ for k in data.keys():
31
+ if k != 'name':
32
+ data[k] = data[k].to(args.device)
33
+ print('>>', data['name'][0])
34
+
35
+ # forward
36
+ st_time = time.time()
37
+ signal, _, (s_h, s_n) = model(data['units'], data['f0'], data['volume'], data['spk_id'])
38
+ ed_time = time.time()
39
+
40
+ # crop
41
+ min_len = np.min([signal.shape[1], data['audio'].shape[1]])
42
+ signal = signal[:,:min_len]
43
+ data['audio'] = data['audio'][:,:min_len]
44
+
45
+ # RTF
46
+ run_time = ed_time - st_time
47
+ song_time = data['audio'].shape[-1] / args.data.sampling_rate
48
+ rtf = run_time / song_time
49
+ print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
50
+ rtf_all.append(rtf)
51
+
52
+ # loss
53
+ loss = loss_func(signal, data['audio'])
54
+
55
+ test_loss += loss.item()
56
+
57
+ # log
58
+ saver.log_audio({fn+'/gt.wav': data['audio'], fn+'/pred.wav': signal})
59
+
60
+ # report
61
+ test_loss /= num_batches
62
+
63
+ # check
64
+ print(' [test_loss] test_loss:', test_loss)
65
+ print(' Real Time Factor', np.mean(rtf_all))
66
+ return test_loss
67
+
68
+
69
+ def train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_test):
70
+ # saver
71
+ saver = Saver(args, initial_global_step=initial_global_step)
72
+
73
+ # model size
74
+ params_count = utils.get_network_paras_amount({'model': model})
75
+ saver.log_info('--- model size ---')
76
+ saver.log_info(params_count)
77
+
78
+ # run
79
+ best_loss = np.inf
80
+ num_batches = len(loader_train)
81
+ model.train()
82
+ saver.log_info('======= start training =======')
83
+ for epoch in range(args.train.epochs):
84
+ for batch_idx, data in enumerate(loader_train):
85
+ saver.global_step_increment()
86
+ optimizer.zero_grad()
87
+
88
+ # unpack data
89
+ for k in data.keys():
90
+ if k != 'name':
91
+ data[k] = data[k].to(args.device)
92
+
93
+ # forward
94
+ signal, _, (s_h, s_n) = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], infer=False)
95
+
96
+ # loss
97
+ loss = loss_func(signal, data['audio'])
98
+
99
+ # handle nan loss
100
+ if torch.isnan(loss):
101
+ raise ValueError(' [x] nan loss ')
102
+ else:
103
+ # backpropagate
104
+ loss.backward()
105
+ optimizer.step()
106
+
107
+ # log loss
108
+ if saver.global_step % args.train.interval_log == 0:
109
+ saver.log_info(
110
+ 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | loss: {:.3f} | time: {} | step: {}'.format(
111
+ epoch,
112
+ batch_idx,
113
+ num_batches,
114
+ args.env.expdir,
115
+ args.train.interval_log/saver.get_interval_time(),
116
+ loss.item(),
117
+ saver.get_total_time(),
118
+ saver.global_step
119
+ )
120
+ )
121
+
122
+ saver.log_value({
123
+ 'train/loss': loss.item()
124
+ })
125
+
126
+ # validation
127
+ if saver.global_step % args.train.interval_val == 0:
128
+ # save latest
129
+ saver.save_model(model, optimizer, postfix=f'{saver.global_step}')
130
+
131
+ # run testing set
132
+ test_loss = test(args, model, loss_func, loader_test, saver)
133
+
134
+ saver.log_info(
135
+ ' --- <validation> --- \nloss: {:.3f}. '.format(
136
+ test_loss,
137
+ )
138
+ )
139
+
140
+ saver.log_value({
141
+ 'validation/loss': test_loss
142
+ })
143
+ model.train()
144
+
145
+ # save best model
146
+ if test_loss < best_loss:
147
+ saver.log_info(' [V] best model updated.')
148
+ saver.save_model(model, optimizer, postfix='best')
149
+ best_loss = test_loss
150
+
151
+
DDSP-SVC/train.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+
5
+ from logger import utils
6
+ from data_loaders import get_data_loaders
7
+ from solver import train
8
+ from ddsp.vocoder import Sins, CombSub, CombSubFast
9
+ from ddsp.loss import RSSLoss
10
+
11
+
12
+ def parse_args(args=None, namespace=None):
13
+ """Parse command-line arguments."""
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "-c",
17
+ "--config",
18
+ type=str,
19
+ required=True,
20
+ help="path to the config file")
21
+ return parser.parse_args(args=args, namespace=namespace)
22
+
23
+
24
+ if __name__ == '__main__':
25
+ # parse commands
26
+ cmd = parse_args()
27
+
28
+ # load config
29
+ args = utils.load_config(cmd.config)
30
+ print(' > config:', cmd.config)
31
+ print(' > exp:', args.env.expdir)
32
+
33
+ # load model
34
+ model = None
35
+
36
+ if args.model.type == 'Sins':
37
+ model = Sins(
38
+ sampling_rate=args.data.sampling_rate,
39
+ block_size=args.data.block_size,
40
+ n_harmonics=args.model.n_harmonics,
41
+ n_mag_allpass=args.model.n_mag_allpass,
42
+ n_mag_noise=args.model.n_mag_noise,
43
+ n_unit=args.data.encoder_out_channels,
44
+ n_spk=args.model.n_spk)
45
+
46
+ elif args.model.type == 'CombSub':
47
+ model = CombSub(
48
+ sampling_rate=args.data.sampling_rate,
49
+ block_size=args.data.block_size,
50
+ n_mag_allpass=args.model.n_mag_allpass,
51
+ n_mag_harmonic=args.model.n_mag_harmonic,
52
+ n_mag_noise=args.model.n_mag_noise,
53
+ n_unit=args.data.encoder_out_channels,
54
+ n_spk=args.model.n_spk)
55
+
56
+ elif args.model.type == 'CombSubFast':
57
+ model = CombSubFast(
58
+ sampling_rate=args.data.sampling_rate,
59
+ block_size=args.data.block_size,
60
+ n_unit=args.data.encoder_out_channels,
61
+ n_spk=args.model.n_spk)
62
+
63
+ else:
64
+ raise ValueError(f" [x] Unknown Model: {args.model.type}")
65
+
66
+ # load parameters
67
+ optimizer = torch.optim.AdamW(model.parameters())
68
+ initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
69
+ for param_group in optimizer.param_groups:
70
+ param_group['lr'] = args.train.lr
71
+ param_group['weight_decay'] = args.train.weight_decay
72
+
73
+ # loss
74
+ loss_func = RSSLoss(args.loss.fft_min, args.loss.fft_max, args.loss.n_scale, device = args.device)
75
+
76
+ # device
77
+ if args.device == 'cuda':
78
+ torch.cuda.set_device(args.env.gpu_id)
79
+ model.to(args.device)
80
+
81
+ for state in optimizer.state.values():
82
+ for k, v in state.items():
83
+ if torch.is_tensor(v):
84
+ state[k] = v.to(args.device)
85
+
86
+ loss_func.to(args.device)
87
+
88
+ # datas
89
+ loader_train, loader_valid = get_data_loaders(args, whole_audio=False)
90
+
91
+ # run
92
+ train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_valid)
93
+
DDSP-SVC/train_diff.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from torch.optim import lr_scheduler
5
+ from logger import utils
6
+ from diffusion.data_loaders import get_data_loaders
7
+ from diffusion.solver import train
8
+ from diffusion.unit2mel import Unit2Mel
9
+ from diffusion.vocoder import Vocoder
10
+
11
+
12
+ def parse_args(args=None, namespace=None):
13
+ """Parse command-line arguments."""
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "-c",
17
+ "--config",
18
+ type=str,
19
+ required=True,
20
+ help="path to the config file")
21
+ return parser.parse_args(args=args, namespace=namespace)
22
+
23
+
24
+ if __name__ == '__main__':
25
+ # parse commands
26
+ cmd = parse_args()
27
+
28
+ # load config
29
+ args = utils.load_config(cmd.config)
30
+ print(' > config:', cmd.config)
31
+ print(' > exp:', args.env.expdir)
32
+
33
+ # load vocoder
34
+ vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
35
+
36
+ # load model
37
+ model = Unit2Mel(
38
+ args.data.encoder_out_channels,
39
+ args.model.n_spk,
40
+ args.model.use_pitch_aug,
41
+ vocoder.dimension,
42
+ args.model.n_layers,
43
+ args.model.n_chans,
44
+ args.model.n_hidden)
45
+
46
+
47
+ # load parameters
48
+ optimizer = torch.optim.AdamW(model.parameters())
49
+ initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
50
+ for param_group in optimizer.param_groups:
51
+ param_group['lr'] = args.train.lr
52
+ param_group['weight_decay'] = args.train.weight_decay
53
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma)
54
+
55
+ # device
56
+ if args.device == 'cuda':
57
+ torch.cuda.set_device(args.env.gpu_id)
58
+ model.to(args.device)
59
+
60
+ for state in optimizer.state.values():
61
+ for k, v in state.items():
62
+ if torch.is_tensor(v):
63
+ state[k] = v.to(args.device)
64
+
65
+ # datas
66
+ loader_train, loader_valid = get_data_loaders(args, whole_audio=False)
67
+
68
+ # run
69
+ train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid)
70
+
DDSP-SVC/webui.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os,subprocess,yaml
3
+
4
+ class WebUI:
5
+ def __init__(self) -> None:
6
+ self.info=Info()
7
+ self.opt_cfg_pth='configs/opt.yaml'
8
+ self.main_ui()
9
+
10
+ def main_ui(self):
11
+ with gr.Blocks() as ui:
12
+ gr.Markdown('## 一个便于训练和推理的DDSP-webui,每一步的说明在下面,可以自己展开看。')
13
+ with gr.Tab("训练/Training"):
14
+ gr.Markdown(self.info.general)
15
+ with gr.Accordion('预训练模型说明',open=False):
16
+ gr.Markdown(self.info.pretrain_model)
17
+ with gr.Accordion('数据集说明',open=False):
18
+ gr.Markdown(self.info.dataset)
19
+
20
+ gr.Markdown('## 生成配置文件')
21
+ with gr.Row():
22
+ self.batch_size=gr.Slider(minimum=2,maximum=60,value=24,label='Batch_size',interactive=True)
23
+ self.learning_rate=gr.Number(value=0.0005,label='学习率',info='和batch_size关系大概是0.0001:6')
24
+ self.f0_extractor=gr.Dropdown(['parselmouth', 'dio', 'harvest', 'crepe'],type='value',value='crepe',label='f0提取器种类',interactive=True)
25
+ self.sampling_rate=gr.Number(value=44100,label='采样率',info='数据集音频的采样率',interactive=True)
26
+ self.n_spk=gr.Number(value=1,label='说话人数量',interactive=True)
27
+ with gr.Row():
28
+ self.device=gr.Dropdown(['cuda','cpu'],value='cuda',label='使用设备',interactive=True)
29
+ self.num_workers=gr.Number(value=2,label='读取数据进程数',info='如果你的设备性能很好,可以设置为0',interactive=True)
30
+ self.cache_all_data=gr.Checkbox(value=True,label='启用缓存',info='将数据全部加载以加速训练',interactive=True)
31
+ self.cache_device=gr.Dropdown(['cuda','cpu'],value='cuda',type='value',label='缓存设备',info='如果你的显存比较大,设置为cuda',interactive=True)
32
+ self.bt_create_config=gr.Button(value='创建配置文件')
33
+
34
+ gr.Markdown('## 预处理')
35
+ with gr.Accordion('预训练说明',open=False):
36
+ gr.Markdown(self.info.preprocess)
37
+ with gr.Row():
38
+ self.bt_open_data_folder=gr.Button('打开数据集文件夹')
39
+ self.bt_preprocess=gr.Button('开始预处理')
40
+ gr.Markdown('## 训练')
41
+ with gr.Accordion('训练说明',open=False):
42
+ gr.Markdown(self.info.train)
43
+ with gr.Row():
44
+ self.bt_train=gr.Button('开始训练')
45
+ self.bt_visual=gr.Button('启动可视化')
46
+ gr.Markdown('启动可视化后[点击打开](http://127.0.0.1:6006)')
47
+
48
+ with gr.Tab('推理/Inference'):
49
+ with gr.Accordion('推理说明',open=False):
50
+ gr.Markdown(self.info.infer)
51
+ with gr.Row():
52
+ self.input_wav=gr.Audio(type='filepath',label='选择待转换音频')
53
+ self.choose_model=gr.Textbox('exp/model_chino.pt',label='模型路径')
54
+ with gr.Row():
55
+ self.keychange=gr.Slider(-24,24,value=0,step=1,label='变调')
56
+ self.id=gr.Number(value=1,label='说话人id')
57
+ self.enhancer_adaptive_key=gr.Number(value=0,label='增强器音区偏移',info='调高可以防止超高音(比如大于G5) 破音,但是低音效果可能会下降')
58
+ with gr.Row():
59
+ self.bt_infer=gr.Button(value='开始转换')
60
+ self.output_wav=gr.Audio(type='filepath',label='输出音频')
61
+
62
+ self.bt_create_config.click(fn=self.create_config)
63
+ self.bt_open_data_folder.click(fn=self.openfolder)
64
+ self.bt_preprocess.click(fn=self.preprocess)
65
+ self.bt_train.click(fn=self.training)
66
+ self.bt_visual.click(fn=self.visualize)
67
+ self.bt_infer.click(fn=self.inference,inputs=[self.input_wav,self.choose_model,self.keychange,self.id,self.enhancer_adaptive_key],outputs=self.output_wav)
68
+ ui.launch(inbrowser=True,server_port=7858)
69
+
70
+ def openfolder(self):
71
+ try:
72
+ os.startfile('data')
73
+ except:
74
+ print('Fail to open folder!')
75
+
76
+
77
+ def create_config(self):
78
+ with open('configs/combsub.yaml','r',encoding='utf-8') as f:
79
+ cfg=yaml.load(f.read(),Loader=yaml.FullLoader)
80
+ cfg['data']['f0_extractor']=str(self.f0_extractor.value)
81
+ cfg['data']['sampling_rate']=int(self.sampling_rate.value)
82
+ cfg['train']['batch_size']=int(self.batch_size.value)
83
+ cfg['device']=str(self.device.value)
84
+ cfg['train']['num_workers']=int(self.num_workers.value)
85
+ cfg['train']['cache_all_data']=str(self.cache_all_data.value)
86
+ cfg['train']['cache_device']=str(self.cache_device.value)
87
+ cfg['train']['lr']=int(self.learning_rate.value)
88
+ print('配置文件信息:'+str(cfg))
89
+ with open(self.opt_cfg_pth,'w',encoding='utf-8') as f:
90
+ yaml.dump(cfg,f)
91
+ print('成功生成配置文件')
92
+
93
+
94
+ def preprocess(self):
95
+ preprocessing_process=subprocess.Popen('python -u preprocess.py -c '+self.opt_cfg_pth,stdout=subprocess.PIPE)
96
+ while preprocessing_process.poll() is None:
97
+ output=preprocessing_process.stdout.readline().decode('utf-8')
98
+ print(output)
99
+ print('预处理完成')
100
+
101
+ def training(self):
102
+ train_process=subprocess.Popen('python -u train.py -c '+self.opt_cfg_pth,stdout=subprocess.PIPE)
103
+ while train_process.poll() is None:
104
+ output=train_process.stdout.readline().decode('utf-8')
105
+ print(output)
106
+
107
+
108
+ def visualize(self):
109
+ tb_process=subprocess.Popen('tensorboard --logdir=exp --port=6006',stdout=subprocess.PIPE)
110
+ while tb_process.poll() is None:
111
+ output=tb_process.stdout.readline().decode('utf-8')
112
+ print(output)
113
+
114
+ def inference(self,input_wav:str,model:str,keychange,id,enhancer_adaptive_key):
115
+ print(input_wav,model)
116
+ output_wav='samples/'+ input_wav.replace('\\','/').split('/')[-1]
117
+ cmd='python -u main.py -i '+input_wav+' -m '+model+' -o '+output_wav+' -k '+str(int(keychange))+' -id '+str(int(id))+' -e true -eak '+str(int(enhancer_adaptive_key))
118
+ infer_process=subprocess.Popen(cmd,stdout=subprocess.PIPE)
119
+ while infer_process.poll() is None:
120
+ output=infer_process.stdout.readline().decode('utf-8')
121
+ print(output)
122
+ print('推理完成')
123
+ return output_wav
124
+
125
+
126
+ class Info:
127
+ def __init__(self) -> None:
128
+ self.general='''
129
+ ### 不看也没事,大致就是
130
+ 1.设置好配置之后点击创建配置文件
131
+ 2.点击‘打开数据集文件夹’,把数据集选个十个塞到data\\train\\val目录下面,剩下的音频全塞到data\\train\\audio下面
132
+ 3.点击‘开始预处理’等待执行完毕
133
+ 4.点击‘开始训练’和‘启动可视化’然后点击右侧链接
134
+ '''
135
+ self.pretrain_model="""
136
+ - **(必要操作)** 下载预训练 [**HubertSoft**](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt) 编码器并将其放到 `pretrain/hubert` 文件夹。
137
+ - 更新:现在支持 ContentVec 编码器了。你可以下载预训练 [ContentVec](https://ibm.ent.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr) 编码器替代 HubertSoft 编码器并修改配置文件以使用它。
138
+ - 从 [DiffSinger 社区声码器项目](https://openvpi.github.io/vocoders) 下载基于预训练声码器的增强器,并解压至 `pretrain/` 文件夹。
139
+ - 注意:你应当下载名称中带有`nsf_hifigan`的压缩文件,而非`nsf_hifigan_finetune`。
140
+ """
141
+ self.dataset="""
142
+ ### 1. 配置训练数据集和验证数据集
143
+
144
+ #### 1.1 手动配置:
145
+
146
+ 将所有的训练集数据 (.wav 格式音频切片) 放到 `data/train/audio`。
147
+
148
+ 将所有的验证集数据 (.wav 格式音频切片) 放到 `data/val/audio`。
149
+
150
+ #### 1.2 程序随机选择(**多人物时不可使用**):
151
+
152
+ 运行`python draw.py`,程序将帮助你挑选验证集数据(可以调整 `draw.py` 中的参数修改抽取文件的数量等参数)。
153
+
154
+ #### 1.3文件夹结构目录展示:
155
+ - 单人物目录结构:
156
+
157
+ ```
158
+ data
159
+ ├─ train
160
+ │ ├─ audio
161
+ │ │ ├─ aaa.wav
162
+ │ │ ├─ bbb.wav
163
+ │ │ └─ ....wav
164
+ │ └─ val
165
+ │ │ ├─ eee.wav
166
+ │ │ ├─ fff.wav
167
+ │ │ └─ ....wav
168
+ ```
169
+ - 多人物目录结构:
170
+
171
+ ```
172
+ data
173
+ ├─ train
174
+ │ ├─ audio
175
+ │ │ ├─ 1
176
+ │ │ │ ├─ aaa.wav
177
+ │ │ │ ├─ bbb.wav
178
+ │ │ │ └─ ....wav
179
+ │ │ ├─ 2
180
+ │ │ │ ├─ ccc.wav
181
+ │ │ │ ├─ ddd.wav
182
+ │ │ │ └─ ....wav
183
+ │ │ └─ ...
184
+ │ └─ val
185
+ │ │ ├─ 1
186
+ │ │ │ ├─ eee.wav
187
+ │ │ │ ├─ fff.wav
188
+ │ │ │ └─ ....wav
189
+ │ │ ├─ 2
190
+ │ │ │ ├─ ggg.wav
191
+ │ │ │ ├─ hhh.wav
192
+ │ │ │ └─ ....wav
193
+ │ │ └─ ...
194
+ ```
195
+ """
196
+ self.preprocess='''
197
+ 您可以在预处理之前修改配置文件 `config/<model_name>.yaml`,默认配置适用于GTX-1660 显卡训练 44.1khz 高采样率合成器。
198
+ ### 备注:
199
+ 1. 请保持所有音频切片的采样率与 yaml 配置文件中的采样率一致!如果不一致,程序可以跑,但训练过程中的重新采样将非常缓慢。(可选:使用Adobe Audition™的响度匹配功能可以一次性完成重采样修改声道和响度匹配。)
200
+
201
+ 2. 训练数据集的音频切片���数建议为约 1000 个,另外长音频切成小段可以加快训练速度,但所有音频切片的时长不应少于 2 秒。如果音频切片太多,则需要较大的内存,配置文件中将 `cache_all_data` 选项设置为 false 可以解决此问题。
202
+
203
+ 3. 验证集的音频切片总数建议为 10 个左右,不要放太多,不然验证过程会很慢。
204
+
205
+ 4. 如果您的数据集质量不是很高,请在配置文件中将 'f0_extractor' 设为 'crepe'。crepe 算法的抗噪性最好,但代价是会极大增加数据预处理所需的时间。
206
+
207
+ 5. 配置文件中的 ‘n_spk’ 参数将控制是否训练多说话人模型。如果您要训练**多说话人**模型,为了对说话人进行编号,所有音频文件夹的名称必须是**不大于 ‘n_spk’ 的正整数**。
208
+ '''
209
+ self.train='''
210
+ ## 训练
211
+
212
+ ### 1. 不使用预训练数据进行训练:
213
+ ```bash
214
+ # 以训练 combsub 模型为例
215
+ python train.py -c configs/combsub.yaml
216
+ ```
217
+ 1. 训练其他模型方法类似。
218
+
219
+ 2. 可以随时中止训练,然后运行相同的命令来继续训练。
220
+
221
+ 3. 微调 (finetune):在中止训练后,重新预处理新数据集或更改训练参数(batchsize、lr等),然后运行相同的命令。
222
+ ### 2. 使用预训练数据(底模)进行训练:
223
+ 1. **使用预训练模型请修改配置文件中的 'n_spk' 参数为 '2' ,同时配置`train`目录结构为多人物目录,不论你是否训练多说话人模型。**
224
+ 2. **如果你要训练一个更多说话人的模型,就不要下载预训练模型了。**
225
+ 3. 欢迎PR训练的多人底模 (请使用授权同意开源的数据集进行训练)。
226
+ 4. 从[**这里**](https://github.com/yxlllc/DDSP-SVC/releases/download/2.0/opencpop+kiritan.zip)下载预训练模型,并将`model_300000.pt`解压到`.\exp\combsub-test\`中
227
+ 5. 同不使用预训练数据进行训练一样,启动训练。
228
+ '''
229
+ self.visualize='''
230
+ ## 可视化
231
+ ```bash
232
+ # 使用tensorboard检查训练状态
233
+ tensorboard --logdir=exp
234
+ ```
235
+ 第一次验证 (validation) 后,在 TensorBoard 中可以看到合成后的测试音频。
236
+
237
+ 注:TensorBoard 中的测试音频是 DDSP-SVC 模型的原始输出,并未通过增强器增强。
238
+ '''
239
+ self.infer='''
240
+ ## 非实时变声
241
+ 1. (**推荐**)使用预训练声码器增强 DDSP 的输出结果:
242
+ ```bash
243
+ # 默认 enhancer_adaptive_key = 0 正常音域范围内将有更高的音质
244
+ # 设置 enhancer_adaptive_key > 0 可将增强器适配于更高的音域
245
+ python main.py -i <input.wav> -m <model_file.pt> -o <output.wav> -k <keychange (semitones)> -id <speaker_id> -e true -eak <enhancer_adaptive_key (semitones)>
246
+ ```
247
+ 2. DDSP 的原始输出结果:
248
+ ```bash
249
+ # 速度快,但音质相对较低(像您在tensorboard里听到的那样)
250
+ python main.py -i <input.wav> -m <model_file.pt> -o <output.wav> -k <keychange (semitones)> -e false -id <speaker_id>
251
+ ```
252
+ 3. 关于 f0 提取器、响应阈值及其他参数,参见:
253
+
254
+ ```bash
255
+ python main.py -h
256
+ ```
257
+ 4. 如果要使用混合说话人(捏音色)功能,增添 “-mix” 选项来设计音色,下面是个例子:
258
+ ```bash
259
+ # 将1号说话人和2号说话人的音色按照0.5:0.5的比例混合
260
+ python main.py -i <input.wav> -m <model_file.pt> -o <output.wav> -k <keychange (semitones)> -mix "{1:0.5, 2:0.5}" -e true -eak 0
261
+ ```
262
+ '''
263
+
264
+
265
+
266
+
267
+ webui=WebUI()