Spaces:
Sleeping
Sleeping
yuancwang
commited on
Commit
·
dce1ab4
1
Parent(s):
5548515
add models
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +18 -0
- models/__init__.py +0 -0
- models/base/__init__.py +7 -0
- models/base/base_dataset.py +344 -0
- models/base/base_inference.py +220 -0
- models/base/base_sampler.py +136 -0
- models/base/base_trainer.py +348 -0
- models/base/new_dataset.py +50 -0
- models/base/new_inference.py +249 -0
- models/base/new_trainer.py +722 -0
- models/svc/__init__.py +0 -0
- models/svc/base/__init__.py +7 -0
- models/svc/base/svc_dataset.py +437 -0
- models/svc/base/svc_inference.py +15 -0
- models/svc/base/svc_trainer.py +111 -0
- models/svc/comosvc/__init__.py +4 -0
- models/svc/comosvc/comosvc.py +377 -0
- models/svc/comosvc/comosvc_inference.py +39 -0
- models/svc/comosvc/comosvc_trainer.py +295 -0
- models/svc/comosvc/utils.py +31 -0
- models/svc/diffusion/__init__.py +0 -0
- models/svc/diffusion/diffusion_inference.py +63 -0
- models/svc/diffusion/diffusion_inference_pipeline.py +47 -0
- models/svc/diffusion/diffusion_trainer.py +88 -0
- models/svc/diffusion/diffusion_wrapper.py +73 -0
- models/svc/transformer/__init__.py +0 -0
- models/svc/transformer/conformer.py +405 -0
- models/svc/transformer/transformer.py +82 -0
- models/svc/transformer/transformer_inference.py +45 -0
- models/svc/transformer/transformer_trainer.py +52 -0
- models/svc/vits/__init__.py +0 -0
- models/svc/vits/vits.py +271 -0
- models/svc/vits/vits_inference.py +84 -0
- models/svc/vits/vits_trainer.py +483 -0
- models/tta/autoencoder/__init__.py +0 -0
- models/tta/autoencoder/autoencoder.py +405 -0
- models/tta/autoencoder/autoencoder_dataset.py +114 -0
- models/tta/autoencoder/autoencoder_loss.py +305 -0
- models/tta/autoencoder/autoencoder_trainer.py +187 -0
- models/tta/ldm/__init__.py +0 -0
- models/tta/ldm/attention.py +329 -0
- models/tta/ldm/audioldm.py +926 -0
- models/tta/ldm/audioldm_dataset.py +153 -0
- models/tta/ldm/audioldm_inference.py +193 -0
- models/tta/ldm/audioldm_trainer.py +251 -0
- models/tta/ldm/inference_utils/utils.py +62 -0
- models/tta/ldm/inference_utils/vocoder.py +408 -0
- models/tts/naturalspeech2/ns2_dataset.py +0 -2
- models/vocoders/autoregressive/autoregressive_vocoder_dataset.py +0 -0
- models/vocoders/autoregressive/autoregressive_vocoder_inference.py +0 -0
app.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
import torch.nn as nn
|
7 |
+
from collections import OrderedDict
|
8 |
+
import json
|
9 |
+
|
10 |
+
from models.tta.autoencoder.autoencoder import AutoencoderKL
|
11 |
+
from models.tta.ldm.inference_utils.vocoder import Generator
|
12 |
+
from models.tta.ldm.audioldm import AudioLDM
|
13 |
+
from transformers import T5EncoderModel, AutoTokenizer
|
14 |
+
from diffusers import PNDMScheduler
|
15 |
+
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
from scipy.io.wavfile import write
|
18 |
+
|
models/__init__.py
ADDED
File without changes
|
models/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .new_trainer import BaseTrainer
|
7 |
+
from .new_inference import BaseInference
|
models/base/base_dataset.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import torch.utils.data
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
from utils.data_utils import *
|
11 |
+
from processors.acoustic_extractor import cal_normalized_mel
|
12 |
+
from text import text_to_sequence
|
13 |
+
from text.text_token_collation import phoneIDCollation
|
14 |
+
|
15 |
+
|
16 |
+
class BaseDataset(torch.utils.data.Dataset):
|
17 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
18 |
+
"""
|
19 |
+
Args:
|
20 |
+
cfg: config
|
21 |
+
dataset: dataset name
|
22 |
+
is_valid: whether to use train or valid dataset
|
23 |
+
"""
|
24 |
+
|
25 |
+
assert isinstance(dataset, str)
|
26 |
+
|
27 |
+
# self.data_root = processed_data_dir
|
28 |
+
self.cfg = cfg
|
29 |
+
|
30 |
+
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
31 |
+
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
32 |
+
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
33 |
+
self.metadata = self.get_metadata()
|
34 |
+
|
35 |
+
"""
|
36 |
+
load spk2id and utt2spk from json file
|
37 |
+
spk2id: {spk1: 0, spk2: 1, ...}
|
38 |
+
utt2spk: {dataset_uid: spk1, ...}
|
39 |
+
"""
|
40 |
+
if cfg.preprocess.use_spkid:
|
41 |
+
spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
|
42 |
+
with open(spk2id_path, "r") as f:
|
43 |
+
self.spk2id = json.load(f)
|
44 |
+
|
45 |
+
utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
|
46 |
+
self.utt2spk = dict()
|
47 |
+
with open(utt2spk_path, "r") as f:
|
48 |
+
for line in f.readlines():
|
49 |
+
utt, spk = line.strip().split("\t")
|
50 |
+
self.utt2spk[utt] = spk
|
51 |
+
|
52 |
+
if cfg.preprocess.use_uv:
|
53 |
+
self.utt2uv_path = {}
|
54 |
+
for utt_info in self.metadata:
|
55 |
+
dataset = utt_info["Dataset"]
|
56 |
+
uid = utt_info["Uid"]
|
57 |
+
utt = "{}_{}".format(dataset, uid)
|
58 |
+
self.utt2uv_path[utt] = os.path.join(
|
59 |
+
cfg.preprocess.processed_dir,
|
60 |
+
dataset,
|
61 |
+
cfg.preprocess.uv_dir,
|
62 |
+
uid + ".npy",
|
63 |
+
)
|
64 |
+
|
65 |
+
if cfg.preprocess.use_frame_pitch:
|
66 |
+
self.utt2frame_pitch_path = {}
|
67 |
+
for utt_info in self.metadata:
|
68 |
+
dataset = utt_info["Dataset"]
|
69 |
+
uid = utt_info["Uid"]
|
70 |
+
utt = "{}_{}".format(dataset, uid)
|
71 |
+
|
72 |
+
self.utt2frame_pitch_path[utt] = os.path.join(
|
73 |
+
cfg.preprocess.processed_dir,
|
74 |
+
dataset,
|
75 |
+
cfg.preprocess.pitch_dir,
|
76 |
+
uid + ".npy",
|
77 |
+
)
|
78 |
+
|
79 |
+
if cfg.preprocess.use_frame_energy:
|
80 |
+
self.utt2frame_energy_path = {}
|
81 |
+
for utt_info in self.metadata:
|
82 |
+
dataset = utt_info["Dataset"]
|
83 |
+
uid = utt_info["Uid"]
|
84 |
+
utt = "{}_{}".format(dataset, uid)
|
85 |
+
|
86 |
+
self.utt2frame_energy_path[utt] = os.path.join(
|
87 |
+
cfg.preprocess.processed_dir,
|
88 |
+
dataset,
|
89 |
+
cfg.preprocess.energy_dir,
|
90 |
+
uid + ".npy",
|
91 |
+
)
|
92 |
+
|
93 |
+
if cfg.preprocess.use_mel:
|
94 |
+
self.utt2mel_path = {}
|
95 |
+
for utt_info in self.metadata:
|
96 |
+
dataset = utt_info["Dataset"]
|
97 |
+
uid = utt_info["Uid"]
|
98 |
+
utt = "{}_{}".format(dataset, uid)
|
99 |
+
|
100 |
+
self.utt2mel_path[utt] = os.path.join(
|
101 |
+
cfg.preprocess.processed_dir,
|
102 |
+
dataset,
|
103 |
+
cfg.preprocess.mel_dir,
|
104 |
+
uid + ".npy",
|
105 |
+
)
|
106 |
+
|
107 |
+
if cfg.preprocess.use_linear:
|
108 |
+
self.utt2linear_path = {}
|
109 |
+
for utt_info in self.metadata:
|
110 |
+
dataset = utt_info["Dataset"]
|
111 |
+
uid = utt_info["Uid"]
|
112 |
+
utt = "{}_{}".format(dataset, uid)
|
113 |
+
|
114 |
+
self.utt2linear_path[utt] = os.path.join(
|
115 |
+
cfg.preprocess.processed_dir,
|
116 |
+
dataset,
|
117 |
+
cfg.preprocess.linear_dir,
|
118 |
+
uid + ".npy",
|
119 |
+
)
|
120 |
+
|
121 |
+
if cfg.preprocess.use_audio:
|
122 |
+
self.utt2audio_path = {}
|
123 |
+
for utt_info in self.metadata:
|
124 |
+
dataset = utt_info["Dataset"]
|
125 |
+
uid = utt_info["Uid"]
|
126 |
+
utt = "{}_{}".format(dataset, uid)
|
127 |
+
|
128 |
+
self.utt2audio_path[utt] = os.path.join(
|
129 |
+
cfg.preprocess.processed_dir,
|
130 |
+
dataset,
|
131 |
+
cfg.preprocess.audio_dir,
|
132 |
+
uid + ".npy",
|
133 |
+
)
|
134 |
+
elif cfg.preprocess.use_label:
|
135 |
+
self.utt2label_path = {}
|
136 |
+
for utt_info in self.metadata:
|
137 |
+
dataset = utt_info["Dataset"]
|
138 |
+
uid = utt_info["Uid"]
|
139 |
+
utt = "{}_{}".format(dataset, uid)
|
140 |
+
|
141 |
+
self.utt2label_path[utt] = os.path.join(
|
142 |
+
cfg.preprocess.processed_dir,
|
143 |
+
dataset,
|
144 |
+
cfg.preprocess.label_dir,
|
145 |
+
uid + ".npy",
|
146 |
+
)
|
147 |
+
elif cfg.preprocess.use_one_hot:
|
148 |
+
self.utt2one_hot_path = {}
|
149 |
+
for utt_info in self.metadata:
|
150 |
+
dataset = utt_info["Dataset"]
|
151 |
+
uid = utt_info["Uid"]
|
152 |
+
utt = "{}_{}".format(dataset, uid)
|
153 |
+
|
154 |
+
self.utt2one_hot_path[utt] = os.path.join(
|
155 |
+
cfg.preprocess.processed_dir,
|
156 |
+
dataset,
|
157 |
+
cfg.preprocess.one_hot_dir,
|
158 |
+
uid + ".npy",
|
159 |
+
)
|
160 |
+
|
161 |
+
if cfg.preprocess.use_text or cfg.preprocess.use_phone:
|
162 |
+
self.utt2seq = {}
|
163 |
+
for utt_info in self.metadata:
|
164 |
+
dataset = utt_info["Dataset"]
|
165 |
+
uid = utt_info["Uid"]
|
166 |
+
utt = "{}_{}".format(dataset, uid)
|
167 |
+
|
168 |
+
if cfg.preprocess.use_text:
|
169 |
+
text = utt_info["Text"]
|
170 |
+
sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
|
171 |
+
elif cfg.preprocess.use_phone:
|
172 |
+
# load phoneme squence from phone file
|
173 |
+
phone_path = os.path.join(
|
174 |
+
processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
|
175 |
+
)
|
176 |
+
with open(phone_path, "r") as fin:
|
177 |
+
phones = fin.readlines()
|
178 |
+
assert len(phones) == 1
|
179 |
+
phones = phones[0].strip()
|
180 |
+
phones_seq = phones.split(" ")
|
181 |
+
|
182 |
+
phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
|
183 |
+
sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
|
184 |
+
|
185 |
+
self.utt2seq[utt] = sequence
|
186 |
+
|
187 |
+
def get_metadata(self):
|
188 |
+
with open(self.metafile_path, "r", encoding="utf-8") as f:
|
189 |
+
metadata = json.load(f)
|
190 |
+
|
191 |
+
return metadata
|
192 |
+
|
193 |
+
def get_dataset_name(self):
|
194 |
+
return self.metadata[0]["Dataset"]
|
195 |
+
|
196 |
+
def __getitem__(self, index):
|
197 |
+
utt_info = self.metadata[index]
|
198 |
+
|
199 |
+
dataset = utt_info["Dataset"]
|
200 |
+
uid = utt_info["Uid"]
|
201 |
+
utt = "{}_{}".format(dataset, uid)
|
202 |
+
|
203 |
+
single_feature = dict()
|
204 |
+
|
205 |
+
if self.cfg.preprocess.use_spkid:
|
206 |
+
single_feature["spk_id"] = np.array(
|
207 |
+
[self.spk2id[self.utt2spk[utt]]], dtype=np.int32
|
208 |
+
)
|
209 |
+
|
210 |
+
if self.cfg.preprocess.use_mel:
|
211 |
+
mel = np.load(self.utt2mel_path[utt])
|
212 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
213 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
214 |
+
# do mel norm
|
215 |
+
mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
|
216 |
+
|
217 |
+
if "target_len" not in single_feature.keys():
|
218 |
+
single_feature["target_len"] = mel.shape[1]
|
219 |
+
single_feature["mel"] = mel.T # [T, n_mels]
|
220 |
+
|
221 |
+
if self.cfg.preprocess.use_linear:
|
222 |
+
linear = np.load(self.utt2linear_path[utt])
|
223 |
+
if "target_len" not in single_feature.keys():
|
224 |
+
single_feature["target_len"] = linear.shape[1]
|
225 |
+
single_feature["linear"] = linear.T # [T, n_linear]
|
226 |
+
|
227 |
+
if self.cfg.preprocess.use_frame_pitch:
|
228 |
+
frame_pitch_path = self.utt2frame_pitch_path[utt]
|
229 |
+
frame_pitch = np.load(frame_pitch_path)
|
230 |
+
if "target_len" not in single_feature.keys():
|
231 |
+
single_feature["target_len"] = len(frame_pitch)
|
232 |
+
aligned_frame_pitch = align_length(
|
233 |
+
frame_pitch, single_feature["target_len"]
|
234 |
+
)
|
235 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
236 |
+
|
237 |
+
if self.cfg.preprocess.use_uv:
|
238 |
+
frame_uv_path = self.utt2uv_path[utt]
|
239 |
+
frame_uv = np.load(frame_uv_path)
|
240 |
+
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
|
241 |
+
aligned_frame_uv = [
|
242 |
+
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
|
243 |
+
]
|
244 |
+
aligned_frame_uv = np.array(aligned_frame_uv)
|
245 |
+
single_feature["frame_uv"] = aligned_frame_uv
|
246 |
+
|
247 |
+
if self.cfg.preprocess.use_frame_energy:
|
248 |
+
frame_energy_path = self.utt2frame_energy_path[utt]
|
249 |
+
frame_energy = np.load(frame_energy_path)
|
250 |
+
if "target_len" not in single_feature.keys():
|
251 |
+
single_feature["target_len"] = len(frame_energy)
|
252 |
+
aligned_frame_energy = align_length(
|
253 |
+
frame_energy, single_feature["target_len"]
|
254 |
+
)
|
255 |
+
single_feature["frame_energy"] = aligned_frame_energy
|
256 |
+
|
257 |
+
if self.cfg.preprocess.use_audio:
|
258 |
+
audio = np.load(self.utt2audio_path[utt])
|
259 |
+
single_feature["audio"] = audio
|
260 |
+
single_feature["audio_len"] = audio.shape[0]
|
261 |
+
|
262 |
+
if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
|
263 |
+
single_feature["phone_seq"] = np.array(self.utt2seq[utt])
|
264 |
+
single_feature["phone_len"] = len(self.utt2seq[utt])
|
265 |
+
|
266 |
+
return single_feature
|
267 |
+
|
268 |
+
def __len__(self):
|
269 |
+
return len(self.metadata)
|
270 |
+
|
271 |
+
|
272 |
+
class BaseCollator(object):
|
273 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
274 |
+
|
275 |
+
def __init__(self, cfg):
|
276 |
+
self.cfg = cfg
|
277 |
+
|
278 |
+
def __call__(self, batch):
|
279 |
+
packed_batch_features = dict()
|
280 |
+
|
281 |
+
# mel: [b, T, n_mels]
|
282 |
+
# frame_pitch, frame_energy: [1, T]
|
283 |
+
# target_len: [1]
|
284 |
+
# spk_id: [b, 1]
|
285 |
+
# mask: [b, T, 1]
|
286 |
+
|
287 |
+
for key in batch[0].keys():
|
288 |
+
if key == "target_len":
|
289 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
290 |
+
[b["target_len"] for b in batch]
|
291 |
+
)
|
292 |
+
masks = [
|
293 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
294 |
+
]
|
295 |
+
packed_batch_features["mask"] = pad_sequence(
|
296 |
+
masks, batch_first=True, padding_value=0
|
297 |
+
)
|
298 |
+
elif key == "phone_len":
|
299 |
+
packed_batch_features["phone_len"] = torch.LongTensor(
|
300 |
+
[b["phone_len"] for b in batch]
|
301 |
+
)
|
302 |
+
masks = [
|
303 |
+
torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
|
304 |
+
]
|
305 |
+
packed_batch_features["phn_mask"] = pad_sequence(
|
306 |
+
masks, batch_first=True, padding_value=0
|
307 |
+
)
|
308 |
+
elif key == "audio_len":
|
309 |
+
packed_batch_features["audio_len"] = torch.LongTensor(
|
310 |
+
[b["audio_len"] for b in batch]
|
311 |
+
)
|
312 |
+
masks = [
|
313 |
+
torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
|
314 |
+
]
|
315 |
+
else:
|
316 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
317 |
+
packed_batch_features[key] = pad_sequence(
|
318 |
+
values, batch_first=True, padding_value=0
|
319 |
+
)
|
320 |
+
return packed_batch_features
|
321 |
+
|
322 |
+
|
323 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
324 |
+
def __init__(self, cfg, args):
|
325 |
+
raise NotImplementedError
|
326 |
+
|
327 |
+
def get_metadata(self):
|
328 |
+
raise NotImplementedError
|
329 |
+
|
330 |
+
def __getitem__(self, index):
|
331 |
+
raise NotImplementedError
|
332 |
+
|
333 |
+
def __len__(self):
|
334 |
+
return len(self.metadata)
|
335 |
+
|
336 |
+
|
337 |
+
class BaseTestCollator(object):
|
338 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
339 |
+
|
340 |
+
def __init__(self, cfg):
|
341 |
+
raise NotImplementedError
|
342 |
+
|
343 |
+
def __call__(self, batch):
|
344 |
+
raise NotImplementedError
|
models/base/base_inference.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from models.vocoders.vocoder_inference import synthesis
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from utils.util import set_all_random_seed
|
19 |
+
from utils.util import load_config
|
20 |
+
|
21 |
+
|
22 |
+
def parse_vocoder(vocoder_dir):
|
23 |
+
r"""Parse vocoder config"""
|
24 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
25 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
26 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
27 |
+
ckpt_path = str(ckpt_list[0])
|
28 |
+
vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
|
29 |
+
vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
|
30 |
+
return vocoder_cfg, ckpt_path
|
31 |
+
|
32 |
+
|
33 |
+
class BaseInference(object):
|
34 |
+
def __init__(self, cfg, args):
|
35 |
+
self.cfg = cfg
|
36 |
+
self.args = args
|
37 |
+
self.model_type = cfg.model_type
|
38 |
+
self.avg_rtf = list()
|
39 |
+
set_all_random_seed(10086)
|
40 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
41 |
+
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
self.device = torch.device("cuda")
|
44 |
+
else:
|
45 |
+
self.device = torch.device("cpu")
|
46 |
+
torch.set_num_threads(10) # inference on 1 core cpu.
|
47 |
+
|
48 |
+
# Load acoustic model
|
49 |
+
self.model = self.create_model().to(self.device)
|
50 |
+
state_dict = self.load_state_dict()
|
51 |
+
self.load_model(state_dict)
|
52 |
+
self.model.eval()
|
53 |
+
|
54 |
+
# Load vocoder model if necessary
|
55 |
+
if self.args.checkpoint_dir_vocoder is not None:
|
56 |
+
self.get_vocoder_info()
|
57 |
+
|
58 |
+
def create_model(self):
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
def load_state_dict(self):
|
62 |
+
self.checkpoint_file = self.args.checkpoint_file
|
63 |
+
if self.checkpoint_file is None:
|
64 |
+
assert self.args.checkpoint_dir is not None
|
65 |
+
checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
|
66 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
67 |
+
self.checkpoint_file = os.path.join(
|
68 |
+
self.args.checkpoint_dir, checkpoint_filename
|
69 |
+
)
|
70 |
+
|
71 |
+
self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
|
72 |
+
|
73 |
+
print("Restore acoustic model from {}".format(self.checkpoint_file))
|
74 |
+
raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
|
75 |
+
self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
|
76 |
+
|
77 |
+
return raw_state_dict
|
78 |
+
|
79 |
+
def load_model(self, model):
|
80 |
+
raise NotImplementedError
|
81 |
+
|
82 |
+
def get_vocoder_info(self):
|
83 |
+
self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
|
84 |
+
self.vocoder_cfg = os.path.join(
|
85 |
+
os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
|
86 |
+
)
|
87 |
+
self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
|
88 |
+
self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
|
89 |
+
self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
|
90 |
+
|
91 |
+
def build_test_utt_data(self):
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
def build_testdata_loader(self, args, target_speaker=None):
|
95 |
+
datasets, collate = self.build_test_dataset()
|
96 |
+
self.test_dataset = datasets(self.cfg, args, target_speaker)
|
97 |
+
self.test_collate = collate(self.cfg)
|
98 |
+
self.test_batch_size = min(
|
99 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
100 |
+
)
|
101 |
+
test_loader = DataLoader(
|
102 |
+
self.test_dataset,
|
103 |
+
collate_fn=self.test_collate,
|
104 |
+
num_workers=self.args.num_workers,
|
105 |
+
batch_size=self.test_batch_size,
|
106 |
+
shuffle=False,
|
107 |
+
)
|
108 |
+
return test_loader
|
109 |
+
|
110 |
+
def inference_each_batch(self, batch_data):
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
def inference_for_batches(self, args, target_speaker=None):
|
114 |
+
###### Construct test_batch ######
|
115 |
+
loader = self.build_testdata_loader(args, target_speaker)
|
116 |
+
|
117 |
+
n_batch = len(loader)
|
118 |
+
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
|
119 |
+
print(
|
120 |
+
"Model eval time: {}, batch_size = {}, n_batch = {}".format(
|
121 |
+
now, self.test_batch_size, n_batch
|
122 |
+
)
|
123 |
+
)
|
124 |
+
self.model.eval()
|
125 |
+
|
126 |
+
###### Inference for each batch ######
|
127 |
+
pred_res = []
|
128 |
+
with torch.no_grad():
|
129 |
+
for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
|
130 |
+
# Put the data to device
|
131 |
+
for k, v in batch_data.items():
|
132 |
+
batch_data[k] = batch_data[k].to(self.device)
|
133 |
+
|
134 |
+
y_pred, stats = self.inference_each_batch(batch_data)
|
135 |
+
|
136 |
+
pred_res += y_pred
|
137 |
+
|
138 |
+
return pred_res
|
139 |
+
|
140 |
+
def inference(self, feature):
|
141 |
+
raise NotImplementedError
|
142 |
+
|
143 |
+
def synthesis_by_vocoder(self, pred):
|
144 |
+
audios_pred = synthesis(
|
145 |
+
self.vocoder_cfg,
|
146 |
+
self.checkpoint_dir_vocoder,
|
147 |
+
len(pred),
|
148 |
+
pred,
|
149 |
+
)
|
150 |
+
return audios_pred
|
151 |
+
|
152 |
+
def __call__(self, utt):
|
153 |
+
feature = self.build_test_utt_data(utt)
|
154 |
+
start_time = time.time()
|
155 |
+
with torch.no_grad():
|
156 |
+
outputs = self.inference(feature)[0]
|
157 |
+
time_used = time.time() - start_time
|
158 |
+
rtf = time_used / (
|
159 |
+
outputs.shape[1]
|
160 |
+
* self.cfg.preprocess.hop_size
|
161 |
+
/ self.cfg.preprocess.sample_rate
|
162 |
+
)
|
163 |
+
print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
|
164 |
+
self.avg_rtf.append(rtf)
|
165 |
+
audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
|
166 |
+
return audios
|
167 |
+
|
168 |
+
|
169 |
+
def base_parser():
|
170 |
+
parser = argparse.ArgumentParser()
|
171 |
+
parser.add_argument(
|
172 |
+
"--config", default="config.json", help="json files for configurations."
|
173 |
+
)
|
174 |
+
parser.add_argument("--use_ddp_inference", default=False)
|
175 |
+
parser.add_argument("--n_workers", default=1, type=int)
|
176 |
+
parser.add_argument("--local_rank", default=-1, type=int)
|
177 |
+
parser.add_argument(
|
178 |
+
"--batch_size", default=1, type=int, help="Batch size for inference"
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--num_workers",
|
182 |
+
default=1,
|
183 |
+
type=int,
|
184 |
+
help="Worker number for inference dataloader",
|
185 |
+
)
|
186 |
+
parser.add_argument(
|
187 |
+
"--checkpoint_dir",
|
188 |
+
type=str,
|
189 |
+
default=None,
|
190 |
+
help="Checkpoint dir including model file and configuration",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--checkpoint_file", help="checkpoint file", type=str, default=None
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--test_list", help="test utterance list for testing", type=str, default=None
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--checkpoint_dir_vocoder",
|
200 |
+
help="Vocoder's checkpoint dir including model file and configuration",
|
201 |
+
type=str,
|
202 |
+
default=None,
|
203 |
+
)
|
204 |
+
parser.add_argument(
|
205 |
+
"--output_dir",
|
206 |
+
type=str,
|
207 |
+
default=None,
|
208 |
+
help="Output dir for saving generated results",
|
209 |
+
)
|
210 |
+
return parser
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
parser = base_parser()
|
215 |
+
args = parser.parse_args()
|
216 |
+
cfg = load_config(args.config)
|
217 |
+
|
218 |
+
# Build inference
|
219 |
+
inference = BaseInference(cfg, args)
|
220 |
+
inference()
|
models/base/base_sampler.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
|
9 |
+
from torch.utils.data import ConcatDataset, Dataset
|
10 |
+
from torch.utils.data.sampler import (
|
11 |
+
BatchSampler,
|
12 |
+
RandomSampler,
|
13 |
+
Sampler,
|
14 |
+
SequentialSampler,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class ScheduledSampler(Sampler):
|
19 |
+
"""A sampler that samples data from a given concat-dataset.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
|
23 |
+
batch_size (int): batch size
|
24 |
+
holistic_shuffle (bool): whether to shuffle the whole dataset or not
|
25 |
+
logger (logging.Logger): logger to print warning message
|
26 |
+
|
27 |
+
Usage:
|
28 |
+
For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
|
29 |
+
>>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]])))
|
30 |
+
[3, 4, 5, 0, 1, 2, 6, 7, 8]
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
concat_dataset,
|
36 |
+
batch_size,
|
37 |
+
holistic_shuffle,
|
38 |
+
logger=None,
|
39 |
+
loader_type="train",
|
40 |
+
):
|
41 |
+
if not isinstance(concat_dataset, ConcatDataset):
|
42 |
+
raise ValueError(
|
43 |
+
"concat_dataset must be an instance of ConcatDataset, but got {}".format(
|
44 |
+
type(concat_dataset)
|
45 |
+
)
|
46 |
+
)
|
47 |
+
if not isinstance(batch_size, int):
|
48 |
+
raise ValueError(
|
49 |
+
"batch_size must be an integer, but got {}".format(type(batch_size))
|
50 |
+
)
|
51 |
+
if not isinstance(holistic_shuffle, bool):
|
52 |
+
raise ValueError(
|
53 |
+
"holistic_shuffle must be a boolean, but got {}".format(
|
54 |
+
type(holistic_shuffle)
|
55 |
+
)
|
56 |
+
)
|
57 |
+
|
58 |
+
self.concat_dataset = concat_dataset
|
59 |
+
self.batch_size = batch_size
|
60 |
+
self.holistic_shuffle = holistic_shuffle
|
61 |
+
|
62 |
+
affected_dataset_name = []
|
63 |
+
affected_dataset_len = []
|
64 |
+
for dataset in concat_dataset.datasets:
|
65 |
+
dataset_len = len(dataset)
|
66 |
+
dataset_name = dataset.get_dataset_name()
|
67 |
+
if dataset_len < batch_size:
|
68 |
+
affected_dataset_name.append(dataset_name)
|
69 |
+
affected_dataset_len.append(dataset_len)
|
70 |
+
|
71 |
+
self.type = loader_type
|
72 |
+
for dataset_name, dataset_len in zip(
|
73 |
+
affected_dataset_name, affected_dataset_len
|
74 |
+
):
|
75 |
+
if not loader_type == "valid":
|
76 |
+
logger.warning(
|
77 |
+
"The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
|
78 |
+
loader_type, dataset_name, dataset_len, batch_size
|
79 |
+
)
|
80 |
+
)
|
81 |
+
|
82 |
+
def __len__(self):
|
83 |
+
# the number of batches with drop last
|
84 |
+
num_of_batches = sum(
|
85 |
+
[
|
86 |
+
math.floor(len(dataset) / self.batch_size)
|
87 |
+
for dataset in self.concat_dataset.datasets
|
88 |
+
]
|
89 |
+
)
|
90 |
+
# if samples are not enough for one batch, we don't drop last
|
91 |
+
if self.type == "valid" and num_of_batches < 1:
|
92 |
+
return len(self.concat_dataset)
|
93 |
+
return num_of_batches * self.batch_size
|
94 |
+
|
95 |
+
def __iter__(self):
|
96 |
+
iters = []
|
97 |
+
for dataset in self.concat_dataset.datasets:
|
98 |
+
iters.append(
|
99 |
+
SequentialSampler(dataset).__iter__()
|
100 |
+
if not self.holistic_shuffle
|
101 |
+
else RandomSampler(dataset).__iter__()
|
102 |
+
)
|
103 |
+
# e.g. [0, 200, 400]
|
104 |
+
init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
|
105 |
+
output_batches = []
|
106 |
+
for dataset_idx in range(len(self.concat_dataset.datasets)):
|
107 |
+
cur_batch = []
|
108 |
+
for idx in iters[dataset_idx]:
|
109 |
+
cur_batch.append(idx + init_indices[dataset_idx])
|
110 |
+
if len(cur_batch) == self.batch_size:
|
111 |
+
output_batches.append(cur_batch)
|
112 |
+
cur_batch = []
|
113 |
+
# if loader_type is valid, we don't need to drop last
|
114 |
+
if self.type == "valid" and len(cur_batch) > 0:
|
115 |
+
output_batches.append(cur_batch)
|
116 |
+
|
117 |
+
# force drop last in training
|
118 |
+
random.shuffle(output_batches)
|
119 |
+
output_indices = [item for sublist in output_batches for item in sublist]
|
120 |
+
return iter(output_indices)
|
121 |
+
|
122 |
+
|
123 |
+
def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
|
124 |
+
sampler = ScheduledSampler(
|
125 |
+
concat_dataset,
|
126 |
+
cfg.train.batch_size,
|
127 |
+
cfg.train.sampler.holistic_shuffle,
|
128 |
+
logger,
|
129 |
+
loader_type,
|
130 |
+
)
|
131 |
+
batch_sampler = BatchSampler(
|
132 |
+
sampler,
|
133 |
+
cfg.train.batch_size,
|
134 |
+
cfg.train.sampler.drop_last if not loader_type == "valid" else False,
|
135 |
+
)
|
136 |
+
return sampler, batch_sampler
|
models/base/base_trainer.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import collections
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
from torch.nn.parallel import DistributedDataParallel
|
15 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
16 |
+
from torch.utils.tensorboard import SummaryWriter
|
17 |
+
|
18 |
+
from models.base.base_sampler import BatchSampler
|
19 |
+
from utils.util import (
|
20 |
+
Logger,
|
21 |
+
remove_older_ckpt,
|
22 |
+
save_config,
|
23 |
+
set_all_random_seed,
|
24 |
+
ValueWindow,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class BaseTrainer(object):
|
29 |
+
def __init__(self, args, cfg):
|
30 |
+
self.args = args
|
31 |
+
self.log_dir = args.log_dir
|
32 |
+
self.cfg = cfg
|
33 |
+
|
34 |
+
self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
|
35 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
36 |
+
if not cfg.train.ddp or args.local_rank == 0:
|
37 |
+
self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
|
38 |
+
self.logger = self.build_logger()
|
39 |
+
self.time_window = ValueWindow(50)
|
40 |
+
|
41 |
+
self.step = 0
|
42 |
+
self.epoch = -1
|
43 |
+
self.max_epochs = self.cfg.train.epochs
|
44 |
+
self.max_steps = self.cfg.train.max_steps
|
45 |
+
|
46 |
+
# set random seed & init distributed training
|
47 |
+
set_all_random_seed(self.cfg.train.random_seed)
|
48 |
+
if cfg.train.ddp:
|
49 |
+
dist.init_process_group(backend="nccl")
|
50 |
+
|
51 |
+
if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
|
52 |
+
self.singers = self.build_singers_lut()
|
53 |
+
|
54 |
+
# setup data_loader
|
55 |
+
self.data_loader = self.build_data_loader()
|
56 |
+
|
57 |
+
# setup model & enable distributed training
|
58 |
+
self.model = self.build_model()
|
59 |
+
print(self.model)
|
60 |
+
|
61 |
+
if isinstance(self.model, dict):
|
62 |
+
for key, value in self.model.items():
|
63 |
+
value.cuda(self.args.local_rank)
|
64 |
+
if key == "PQMF":
|
65 |
+
continue
|
66 |
+
if cfg.train.ddp:
|
67 |
+
self.model[key] = DistributedDataParallel(
|
68 |
+
value, device_ids=[self.args.local_rank]
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
self.model.cuda(self.args.local_rank)
|
72 |
+
if cfg.train.ddp:
|
73 |
+
self.model = DistributedDataParallel(
|
74 |
+
self.model, device_ids=[self.args.local_rank]
|
75 |
+
)
|
76 |
+
|
77 |
+
# create criterion
|
78 |
+
self.criterion = self.build_criterion()
|
79 |
+
if isinstance(self.criterion, dict):
|
80 |
+
for key, value in self.criterion.items():
|
81 |
+
self.criterion[key].cuda(args.local_rank)
|
82 |
+
else:
|
83 |
+
self.criterion.cuda(self.args.local_rank)
|
84 |
+
|
85 |
+
# optimizer
|
86 |
+
self.optimizer = self.build_optimizer()
|
87 |
+
self.scheduler = self.build_scheduler()
|
88 |
+
|
89 |
+
# save config file
|
90 |
+
self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
|
91 |
+
|
92 |
+
def build_logger(self):
|
93 |
+
log_file = os.path.join(self.checkpoint_dir, "train.log")
|
94 |
+
logger = Logger(log_file, level=self.args.log_level).logger
|
95 |
+
|
96 |
+
return logger
|
97 |
+
|
98 |
+
def build_dataset(self):
|
99 |
+
raise NotImplementedError
|
100 |
+
|
101 |
+
def build_data_loader(self):
|
102 |
+
Dataset, Collator = self.build_dataset()
|
103 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
104 |
+
datasets_list = []
|
105 |
+
for dataset in self.cfg.dataset:
|
106 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
107 |
+
datasets_list.append(subdataset)
|
108 |
+
train_dataset = ConcatDataset(datasets_list)
|
109 |
+
|
110 |
+
train_collate = Collator(self.cfg)
|
111 |
+
# TODO: multi-GPU training
|
112 |
+
if self.cfg.train.ddp:
|
113 |
+
raise NotImplementedError("DDP is not supported yet.")
|
114 |
+
|
115 |
+
# sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
|
116 |
+
batch_sampler = BatchSampler(
|
117 |
+
cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
|
118 |
+
)
|
119 |
+
|
120 |
+
# use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
|
121 |
+
train_loader = DataLoader(
|
122 |
+
train_dataset,
|
123 |
+
collate_fn=train_collate,
|
124 |
+
num_workers=self.args.num_workers,
|
125 |
+
batch_sampler=batch_sampler,
|
126 |
+
pin_memory=False,
|
127 |
+
)
|
128 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
129 |
+
datasets_list = []
|
130 |
+
for dataset in self.cfg.dataset:
|
131 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
132 |
+
datasets_list.append(subdataset)
|
133 |
+
valid_dataset = ConcatDataset(datasets_list)
|
134 |
+
valid_collate = Collator(self.cfg)
|
135 |
+
batch_sampler = BatchSampler(
|
136 |
+
cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
|
137 |
+
)
|
138 |
+
valid_loader = DataLoader(
|
139 |
+
valid_dataset,
|
140 |
+
collate_fn=valid_collate,
|
141 |
+
num_workers=1,
|
142 |
+
batch_sampler=batch_sampler,
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
raise NotImplementedError("DDP is not supported yet.")
|
146 |
+
# valid_loader = None
|
147 |
+
data_loader = {"train": train_loader, "valid": valid_loader}
|
148 |
+
return data_loader
|
149 |
+
|
150 |
+
def build_singers_lut(self):
|
151 |
+
# combine singers
|
152 |
+
if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
|
153 |
+
singers = collections.OrderedDict()
|
154 |
+
else:
|
155 |
+
with open(
|
156 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
|
157 |
+
) as singer_file:
|
158 |
+
singers = json.load(singer_file)
|
159 |
+
singer_count = len(singers)
|
160 |
+
for dataset in self.cfg.dataset:
|
161 |
+
singer_lut_path = os.path.join(
|
162 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
163 |
+
)
|
164 |
+
with open(singer_lut_path, "r") as singer_lut_path:
|
165 |
+
singer_lut = json.load(singer_lut_path)
|
166 |
+
for singer in singer_lut.keys():
|
167 |
+
if singer not in singers:
|
168 |
+
singers[singer] = singer_count
|
169 |
+
singer_count += 1
|
170 |
+
with open(
|
171 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
|
172 |
+
) as singer_file:
|
173 |
+
json.dump(singers, singer_file, indent=4, ensure_ascii=False)
|
174 |
+
print(
|
175 |
+
"singers have been dumped to {}".format(
|
176 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
|
177 |
+
)
|
178 |
+
)
|
179 |
+
return singers
|
180 |
+
|
181 |
+
def build_model(self):
|
182 |
+
raise NotImplementedError()
|
183 |
+
|
184 |
+
def build_optimizer(self):
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
def build_scheduler(self):
|
188 |
+
raise NotImplementedError()
|
189 |
+
|
190 |
+
def build_criterion(self):
|
191 |
+
raise NotImplementedError
|
192 |
+
|
193 |
+
def get_state_dict(self):
|
194 |
+
raise NotImplementedError
|
195 |
+
|
196 |
+
def save_config_file(self):
|
197 |
+
save_config(self.config_save_path, self.cfg)
|
198 |
+
|
199 |
+
# TODO, save without module.
|
200 |
+
def save_checkpoint(self, state_dict, saved_model_path):
|
201 |
+
torch.save(state_dict, saved_model_path)
|
202 |
+
|
203 |
+
def load_checkpoint(self):
|
204 |
+
checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
|
205 |
+
assert os.path.exists(checkpoint_path)
|
206 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
207 |
+
model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
|
208 |
+
assert os.path.exists(model_path)
|
209 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
210 |
+
self.logger.info(f"Re(store) from {model_path}")
|
211 |
+
checkpoint = torch.load(model_path, map_location="cpu")
|
212 |
+
return checkpoint
|
213 |
+
|
214 |
+
def load_model(self, checkpoint):
|
215 |
+
raise NotImplementedError
|
216 |
+
|
217 |
+
def restore(self):
|
218 |
+
checkpoint = self.load_checkpoint()
|
219 |
+
self.load_model(checkpoint)
|
220 |
+
|
221 |
+
def train_step(self, data):
|
222 |
+
raise NotImplementedError(
|
223 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
224 |
+
f"your sub-class of {self.__class__.__name__}. "
|
225 |
+
)
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def eval_step(self):
|
229 |
+
raise NotImplementedError(
|
230 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
231 |
+
f"your sub-class of {self.__class__.__name__}. "
|
232 |
+
)
|
233 |
+
|
234 |
+
def write_summary(self, losses, stats):
|
235 |
+
raise NotImplementedError(
|
236 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
237 |
+
f"your sub-class of {self.__class__.__name__}. "
|
238 |
+
)
|
239 |
+
|
240 |
+
def write_valid_summary(self, losses, stats):
|
241 |
+
raise NotImplementedError(
|
242 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
243 |
+
f"your sub-class of {self.__class__.__name__}. "
|
244 |
+
)
|
245 |
+
|
246 |
+
def echo_log(self, losses, mode="Training"):
|
247 |
+
message = [
|
248 |
+
"{} - Epoch {} Step {}: [{:.3f} s/step]".format(
|
249 |
+
mode, self.epoch + 1, self.step, self.time_window.average
|
250 |
+
)
|
251 |
+
]
|
252 |
+
|
253 |
+
for key in sorted(losses.keys()):
|
254 |
+
if isinstance(losses[key], dict):
|
255 |
+
for k, v in losses[key].items():
|
256 |
+
message.append(
|
257 |
+
str(k).split("/")[-1] + "=" + str(round(float(v), 5))
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
message.append(
|
261 |
+
str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
|
262 |
+
)
|
263 |
+
self.logger.info(", ".join(message))
|
264 |
+
|
265 |
+
def eval_epoch(self):
|
266 |
+
self.logger.info("Validation...")
|
267 |
+
valid_losses = {}
|
268 |
+
for i, batch_data in enumerate(self.data_loader["valid"]):
|
269 |
+
for k, v in batch_data.items():
|
270 |
+
if isinstance(v, torch.Tensor):
|
271 |
+
batch_data[k] = v.cuda()
|
272 |
+
valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
|
273 |
+
for key in valid_loss:
|
274 |
+
if key not in valid_losses:
|
275 |
+
valid_losses[key] = 0
|
276 |
+
valid_losses[key] += valid_loss[key]
|
277 |
+
|
278 |
+
# Add mel and audio to the Tensorboard
|
279 |
+
# Average loss
|
280 |
+
for key in valid_losses:
|
281 |
+
valid_losses[key] /= i + 1
|
282 |
+
self.echo_log(valid_losses, "Valid")
|
283 |
+
return valid_losses, valid_stats
|
284 |
+
|
285 |
+
def train_epoch(self):
|
286 |
+
for i, batch_data in enumerate(self.data_loader["train"]):
|
287 |
+
start_time = time.time()
|
288 |
+
# Put the data to cuda device
|
289 |
+
for k, v in batch_data.items():
|
290 |
+
if isinstance(v, torch.Tensor):
|
291 |
+
batch_data[k] = v.cuda(self.args.local_rank)
|
292 |
+
|
293 |
+
# Training step
|
294 |
+
train_losses, train_stats, total_loss = self.train_step(batch_data)
|
295 |
+
self.time_window.append(time.time() - start_time)
|
296 |
+
|
297 |
+
if self.args.local_rank == 0 or not self.cfg.train.ddp:
|
298 |
+
if self.step % self.args.stdout_interval == 0:
|
299 |
+
self.echo_log(train_losses, "Training")
|
300 |
+
|
301 |
+
if self.step % self.cfg.train.save_summary_steps == 0:
|
302 |
+
self.logger.info(f"Save summary as step {self.step}")
|
303 |
+
self.write_summary(train_losses, train_stats)
|
304 |
+
|
305 |
+
if (
|
306 |
+
self.step % self.cfg.train.save_checkpoints_steps == 0
|
307 |
+
and self.step != 0
|
308 |
+
):
|
309 |
+
saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
|
310 |
+
self.step, total_loss
|
311 |
+
)
|
312 |
+
saved_model_path = os.path.join(
|
313 |
+
self.checkpoint_dir, saved_model_name
|
314 |
+
)
|
315 |
+
saved_state_dict = self.get_state_dict()
|
316 |
+
self.save_checkpoint(saved_state_dict, saved_model_path)
|
317 |
+
self.save_config_file()
|
318 |
+
# keep max n models
|
319 |
+
remove_older_ckpt(
|
320 |
+
saved_model_name,
|
321 |
+
self.checkpoint_dir,
|
322 |
+
max_to_keep=self.cfg.train.keep_checkpoint_max,
|
323 |
+
)
|
324 |
+
|
325 |
+
if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
|
326 |
+
if isinstance(self.model, dict):
|
327 |
+
for key in self.model.keys():
|
328 |
+
self.model[key].eval()
|
329 |
+
else:
|
330 |
+
self.model.eval()
|
331 |
+
# Evaluate one epoch and get average loss
|
332 |
+
valid_losses, valid_stats = self.eval_epoch()
|
333 |
+
if isinstance(self.model, dict):
|
334 |
+
for key in self.model.keys():
|
335 |
+
self.model[key].train()
|
336 |
+
else:
|
337 |
+
self.model.train()
|
338 |
+
# Write validation losses to summary.
|
339 |
+
self.write_valid_summary(valid_losses, valid_stats)
|
340 |
+
self.step += 1
|
341 |
+
|
342 |
+
def train(self):
|
343 |
+
for epoch in range(max(0, self.epoch), self.max_epochs):
|
344 |
+
self.train_epoch()
|
345 |
+
self.epoch += 1
|
346 |
+
if self.step > self.max_steps:
|
347 |
+
self.logger.info("Training finished!")
|
348 |
+
break
|
models/base/new_dataset.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
from abc import abstractmethod
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import json5
|
12 |
+
import torch
|
13 |
+
import yaml
|
14 |
+
|
15 |
+
|
16 |
+
# TODO: for training and validating
|
17 |
+
class BaseDataset(torch.utils.data.Dataset):
|
18 |
+
r"""Base dataset for training and validating."""
|
19 |
+
|
20 |
+
def __init__(self, args, cfg, is_valid=False):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
25 |
+
r"""Test dataset for inference."""
|
26 |
+
|
27 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
28 |
+
assert infer_type in ["from_dataset", "from_file"]
|
29 |
+
|
30 |
+
self.args = args
|
31 |
+
self.cfg = cfg
|
32 |
+
self.infer_type = infer_type
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
def __getitem__(self, index):
|
36 |
+
pass
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.metadata)
|
40 |
+
|
41 |
+
def get_metadata(self):
|
42 |
+
path = Path(self.args.source)
|
43 |
+
if path.suffix == ".json" or path.suffix == ".jsonc":
|
44 |
+
metadata = json5.load(open(self.args.source, "r"))
|
45 |
+
elif path.suffix == ".yaml" or path.suffix == ".yml":
|
46 |
+
metadata = yaml.full_load(open(self.args.source, "r"))
|
47 |
+
else:
|
48 |
+
raise ValueError(f"Unsupported file type: {path.suffix}")
|
49 |
+
|
50 |
+
return metadata
|
models/base/new_inference.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from abc import abstractmethod
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import accelerate
|
14 |
+
import json5
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from accelerate.logging import get_logger
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
from models.vocoders.vocoder_inference import synthesis
|
21 |
+
from utils.io import save_audio
|
22 |
+
from utils.util import load_config
|
23 |
+
from utils.audio_slicer import is_silence
|
24 |
+
|
25 |
+
EPS = 1.0e-12
|
26 |
+
|
27 |
+
|
28 |
+
class BaseInference(object):
|
29 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
start = time.monotonic_ns()
|
33 |
+
self.args = args
|
34 |
+
self.cfg = cfg
|
35 |
+
|
36 |
+
assert infer_type in ["from_dataset", "from_file"]
|
37 |
+
self.infer_type = infer_type
|
38 |
+
|
39 |
+
# init with accelerate
|
40 |
+
self.accelerator = accelerate.Accelerator()
|
41 |
+
self.accelerator.wait_for_everyone()
|
42 |
+
|
43 |
+
# Use accelerate logger for distributed inference
|
44 |
+
with self.accelerator.main_process_first():
|
45 |
+
self.logger = get_logger("inference", log_level=args.log_level)
|
46 |
+
|
47 |
+
# Log some info
|
48 |
+
self.logger.info("=" * 56)
|
49 |
+
self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
|
50 |
+
self.logger.info("=" * 56)
|
51 |
+
self.logger.info("\n")
|
52 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
53 |
+
|
54 |
+
self.acoustics_dir = args.acoustics_dir
|
55 |
+
self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
|
56 |
+
self.vocoder_dir = args.vocoder_dir
|
57 |
+
self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
|
58 |
+
# should be in svc inferencer
|
59 |
+
# self.target_singer = args.target_singer
|
60 |
+
# self.logger.info(f"Target singers: {args.target_singer}")
|
61 |
+
# self.trans_key = args.trans_key
|
62 |
+
# self.logger.info(f"Trans key: {args.trans_key}")
|
63 |
+
|
64 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
65 |
+
|
66 |
+
# set random seed
|
67 |
+
with self.accelerator.main_process_first():
|
68 |
+
start = time.monotonic_ns()
|
69 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
70 |
+
end = time.monotonic_ns()
|
71 |
+
self.logger.debug(
|
72 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
73 |
+
)
|
74 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
75 |
+
|
76 |
+
# setup data_loader
|
77 |
+
with self.accelerator.main_process_first():
|
78 |
+
self.logger.info("Building dataset...")
|
79 |
+
start = time.monotonic_ns()
|
80 |
+
self.test_dataloader = self._build_dataloader()
|
81 |
+
end = time.monotonic_ns()
|
82 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
83 |
+
|
84 |
+
# setup model
|
85 |
+
with self.accelerator.main_process_first():
|
86 |
+
self.logger.info("Building model...")
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self.model = self._build_model()
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
# self.logger.debug(self.model)
|
91 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
|
92 |
+
|
93 |
+
# init with accelerate
|
94 |
+
self.logger.info("Initializing accelerate...")
|
95 |
+
start = time.monotonic_ns()
|
96 |
+
self.accelerator = accelerate.Accelerator()
|
97 |
+
self.model = self.accelerator.prepare(self.model)
|
98 |
+
end = time.monotonic_ns()
|
99 |
+
self.accelerator.wait_for_everyone()
|
100 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
|
101 |
+
|
102 |
+
with self.accelerator.main_process_first():
|
103 |
+
self.logger.info("Loading checkpoint...")
|
104 |
+
start = time.monotonic_ns()
|
105 |
+
# TODO: Also, suppose only use latest one yet
|
106 |
+
self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
|
107 |
+
end = time.monotonic_ns()
|
108 |
+
self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
|
109 |
+
|
110 |
+
self.model.eval()
|
111 |
+
self.accelerator.wait_for_everyone()
|
112 |
+
|
113 |
+
### Abstract methods ###
|
114 |
+
@abstractmethod
|
115 |
+
def _build_test_dataset(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
@abstractmethod
|
119 |
+
def _build_model(self):
|
120 |
+
pass
|
121 |
+
|
122 |
+
@abstractmethod
|
123 |
+
@torch.inference_mode()
|
124 |
+
def _inference_each_batch(self, batch_data):
|
125 |
+
pass
|
126 |
+
|
127 |
+
### Abstract methods end ###
|
128 |
+
|
129 |
+
@torch.inference_mode()
|
130 |
+
def inference(self):
|
131 |
+
for i, batch in enumerate(self.test_dataloader):
|
132 |
+
y_pred = self._inference_each_batch(batch).cpu()
|
133 |
+
mel_min, mel_max = self.test_dataset.target_mel_extrema
|
134 |
+
y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
|
135 |
+
y_ls = y_pred.chunk(self.test_batch_size)
|
136 |
+
tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
|
137 |
+
j = 0
|
138 |
+
for it, l in zip(y_ls, tgt_ls):
|
139 |
+
l = l.item()
|
140 |
+
it = it.squeeze(0)[:l]
|
141 |
+
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
|
142 |
+
torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
|
143 |
+
j += 1
|
144 |
+
|
145 |
+
vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
|
146 |
+
|
147 |
+
res = synthesis(
|
148 |
+
cfg=vocoder_cfg,
|
149 |
+
vocoder_weight_file=vocoder_ckpt,
|
150 |
+
n_samples=None,
|
151 |
+
pred=[
|
152 |
+
torch.load(
|
153 |
+
os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
|
154 |
+
).numpy(force=True)
|
155 |
+
for i in self.test_dataset.metadata
|
156 |
+
],
|
157 |
+
)
|
158 |
+
|
159 |
+
output_audio_files = []
|
160 |
+
for it, wav in zip(self.test_dataset.metadata, res):
|
161 |
+
uid = it["Uid"]
|
162 |
+
file = os.path.join(self.args.output_dir, f"{uid}.wav")
|
163 |
+
output_audio_files.append(file)
|
164 |
+
|
165 |
+
wav = wav.numpy(force=True)
|
166 |
+
save_audio(
|
167 |
+
file,
|
168 |
+
wav,
|
169 |
+
self.cfg.preprocess.sample_rate,
|
170 |
+
add_silence=False,
|
171 |
+
turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
|
172 |
+
)
|
173 |
+
os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
|
174 |
+
|
175 |
+
return sorted(output_audio_files)
|
176 |
+
|
177 |
+
# TODO: LEGACY CODE
|
178 |
+
def _build_dataloader(self):
|
179 |
+
datasets, collate = self._build_test_dataset()
|
180 |
+
self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
|
181 |
+
self.test_collate = collate(self.cfg)
|
182 |
+
self.test_batch_size = min(
|
183 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
184 |
+
)
|
185 |
+
test_dataloader = DataLoader(
|
186 |
+
self.test_dataset,
|
187 |
+
collate_fn=self.test_collate,
|
188 |
+
num_workers=1,
|
189 |
+
batch_size=self.test_batch_size,
|
190 |
+
shuffle=False,
|
191 |
+
)
|
192 |
+
return test_dataloader
|
193 |
+
|
194 |
+
def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
|
195 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
196 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
197 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
198 |
+
method after** ``accelerator.prepare()``.
|
199 |
+
"""
|
200 |
+
if checkpoint_path is None:
|
201 |
+
ls = []
|
202 |
+
for i in Path(checkpoint_dir).iterdir():
|
203 |
+
if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
|
204 |
+
ls.append(i)
|
205 |
+
ls.sort(
|
206 |
+
key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
|
207 |
+
)
|
208 |
+
checkpoint_path = ls[0]
|
209 |
+
else:
|
210 |
+
checkpoint_path = Path(checkpoint_path)
|
211 |
+
self.accelerator.load_state(str(checkpoint_path))
|
212 |
+
# set epoch and step
|
213 |
+
self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
|
214 |
+
self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
|
215 |
+
return str(checkpoint_path)
|
216 |
+
|
217 |
+
@staticmethod
|
218 |
+
def _set_random_seed(seed):
|
219 |
+
r"""Set random seed for all possible random modules."""
|
220 |
+
random.seed(seed)
|
221 |
+
np.random.seed(seed)
|
222 |
+
torch.random.manual_seed(seed)
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def _parse_vocoder(vocoder_dir):
|
226 |
+
r"""Parse vocoder config"""
|
227 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
228 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
229 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
230 |
+
ckpt_path = str(ckpt_list[0])
|
231 |
+
vocoder_cfg = load_config(
|
232 |
+
os.path.join(vocoder_dir, "args.json"), lowercase=True
|
233 |
+
)
|
234 |
+
return vocoder_cfg, ckpt_path
|
235 |
+
|
236 |
+
@staticmethod
|
237 |
+
def __count_parameters(model):
|
238 |
+
return sum(p.numel() for p in model.parameters())
|
239 |
+
|
240 |
+
def __dump_cfg(self, path):
|
241 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
242 |
+
json5.dump(
|
243 |
+
self.cfg,
|
244 |
+
open(path, "w"),
|
245 |
+
indent=4,
|
246 |
+
sort_keys=True,
|
247 |
+
ensure_ascii=False,
|
248 |
+
quote_keys=True,
|
249 |
+
)
|
models/base/new_trainer.py
ADDED
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import shutil
|
10 |
+
import time
|
11 |
+
from abc import abstractmethod
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import accelerate
|
15 |
+
import json5
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from accelerate.logging import get_logger
|
19 |
+
from accelerate.utils import ProjectConfiguration
|
20 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from models.base.base_sampler import build_samplers
|
24 |
+
from optimizer.optimizers import NoamLR
|
25 |
+
|
26 |
+
|
27 |
+
class BaseTrainer(object):
|
28 |
+
r"""The base trainer for all tasks. Any trainer should inherit from this class."""
|
29 |
+
|
30 |
+
def __init__(self, args=None, cfg=None):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.args = args
|
34 |
+
self.cfg = cfg
|
35 |
+
|
36 |
+
cfg.exp_name = args.exp_name
|
37 |
+
|
38 |
+
# init with accelerate
|
39 |
+
self._init_accelerator()
|
40 |
+
self.accelerator.wait_for_everyone()
|
41 |
+
|
42 |
+
# Use accelerate logger for distributed training
|
43 |
+
with self.accelerator.main_process_first():
|
44 |
+
self.logger = get_logger(args.exp_name, log_level=args.log_level)
|
45 |
+
|
46 |
+
# Log some info
|
47 |
+
self.logger.info("=" * 56)
|
48 |
+
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
|
49 |
+
self.logger.info("=" * 56)
|
50 |
+
self.logger.info("\n")
|
51 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
52 |
+
self.logger.info(f"Experiment name: {args.exp_name}")
|
53 |
+
self.logger.info(f"Experiment directory: {self.exp_dir}")
|
54 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
55 |
+
if self.accelerator.is_main_process:
|
56 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
57 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
58 |
+
|
59 |
+
# init counts
|
60 |
+
self.batch_count: int = 0
|
61 |
+
self.step: int = 0
|
62 |
+
self.epoch: int = 0
|
63 |
+
self.max_epoch = (
|
64 |
+
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
|
65 |
+
)
|
66 |
+
self.logger.info(
|
67 |
+
"Max epoch: {}".format(
|
68 |
+
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
|
69 |
+
)
|
70 |
+
)
|
71 |
+
|
72 |
+
# Check values
|
73 |
+
if self.accelerator.is_main_process:
|
74 |
+
self.__check_basic_configs()
|
75 |
+
# Set runtime configs
|
76 |
+
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
|
77 |
+
self.checkpoints_path = [
|
78 |
+
[] for _ in range(len(self.save_checkpoint_stride))
|
79 |
+
]
|
80 |
+
self.keep_last = [
|
81 |
+
i if i > 0 else float("inf") for i in self.cfg.train.keep_last
|
82 |
+
]
|
83 |
+
self.run_eval = self.cfg.train.run_eval
|
84 |
+
|
85 |
+
# set random seed
|
86 |
+
with self.accelerator.main_process_first():
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
self.logger.debug(
|
91 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
92 |
+
)
|
93 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
94 |
+
|
95 |
+
# setup data_loader
|
96 |
+
with self.accelerator.main_process_first():
|
97 |
+
self.logger.info("Building dataset...")
|
98 |
+
start = time.monotonic_ns()
|
99 |
+
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
|
100 |
+
end = time.monotonic_ns()
|
101 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
102 |
+
|
103 |
+
# setup model
|
104 |
+
with self.accelerator.main_process_first():
|
105 |
+
self.logger.info("Building model...")
|
106 |
+
start = time.monotonic_ns()
|
107 |
+
self.model = self._build_model()
|
108 |
+
end = time.monotonic_ns()
|
109 |
+
self.logger.debug(self.model)
|
110 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
|
111 |
+
self.logger.info(
|
112 |
+
f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
|
113 |
+
)
|
114 |
+
# optimizer & scheduler
|
115 |
+
with self.accelerator.main_process_first():
|
116 |
+
self.logger.info("Building optimizer and scheduler...")
|
117 |
+
start = time.monotonic_ns()
|
118 |
+
self.optimizer = self.__build_optimizer()
|
119 |
+
self.scheduler = self.__build_scheduler()
|
120 |
+
end = time.monotonic_ns()
|
121 |
+
self.logger.info(
|
122 |
+
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
|
123 |
+
)
|
124 |
+
|
125 |
+
# accelerate prepare
|
126 |
+
self.logger.info("Initializing accelerate...")
|
127 |
+
start = time.monotonic_ns()
|
128 |
+
(
|
129 |
+
self.train_dataloader,
|
130 |
+
self.valid_dataloader,
|
131 |
+
self.model,
|
132 |
+
self.optimizer,
|
133 |
+
self.scheduler,
|
134 |
+
) = self.accelerator.prepare(
|
135 |
+
self.train_dataloader,
|
136 |
+
self.valid_dataloader,
|
137 |
+
self.model,
|
138 |
+
self.optimizer,
|
139 |
+
self.scheduler,
|
140 |
+
)
|
141 |
+
end = time.monotonic_ns()
|
142 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
|
143 |
+
|
144 |
+
# create criterion
|
145 |
+
with self.accelerator.main_process_first():
|
146 |
+
self.logger.info("Building criterion...")
|
147 |
+
start = time.monotonic_ns()
|
148 |
+
self.criterion = self._build_criterion()
|
149 |
+
end = time.monotonic_ns()
|
150 |
+
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
|
151 |
+
|
152 |
+
# Resume or Finetune
|
153 |
+
with self.accelerator.main_process_first():
|
154 |
+
if args.resume:
|
155 |
+
## Automatically resume according to the current exprimental name
|
156 |
+
self.logger.info("Resuming from {}...".format(self.checkpoint_dir))
|
157 |
+
start = time.monotonic_ns()
|
158 |
+
ckpt_path = self.__load_model(
|
159 |
+
checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
|
160 |
+
)
|
161 |
+
end = time.monotonic_ns()
|
162 |
+
self.logger.info(
|
163 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
164 |
+
)
|
165 |
+
self.checkpoints_path = json.load(
|
166 |
+
open(os.path.join(ckpt_path, "ckpts.json"), "r")
|
167 |
+
)
|
168 |
+
elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "":
|
169 |
+
## Resume from the given checkpoint path
|
170 |
+
if not os.path.exists(args.resume_from_ckpt_path):
|
171 |
+
raise ValueError(
|
172 |
+
"[Error] The resumed checkpoint path {} don't exist.".format(
|
173 |
+
args.resume_from_ckpt_path
|
174 |
+
)
|
175 |
+
)
|
176 |
+
|
177 |
+
self.logger.info(
|
178 |
+
"Resuming from {}...".format(args.resume_from_ckpt_path)
|
179 |
+
)
|
180 |
+
start = time.monotonic_ns()
|
181 |
+
ckpt_path = self.__load_model(
|
182 |
+
checkpoint_path=args.resume_from_ckpt_path,
|
183 |
+
resume_type=args.resume_type,
|
184 |
+
)
|
185 |
+
end = time.monotonic_ns()
|
186 |
+
self.logger.info(
|
187 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
188 |
+
)
|
189 |
+
|
190 |
+
# save config file path
|
191 |
+
self.config_save_path = os.path.join(self.exp_dir, "args.json")
|
192 |
+
|
193 |
+
### Following are abstract methods that should be implemented in child classes ###
|
194 |
+
@abstractmethod
|
195 |
+
def _build_dataset(self):
|
196 |
+
r"""Build dataset for model training/validating/evaluating."""
|
197 |
+
pass
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
@abstractmethod
|
201 |
+
def _build_criterion():
|
202 |
+
r"""Build criterion function for model loss calculation."""
|
203 |
+
pass
|
204 |
+
|
205 |
+
@abstractmethod
|
206 |
+
def _build_model(self):
|
207 |
+
r"""Build model for training/validating/evaluating."""
|
208 |
+
pass
|
209 |
+
|
210 |
+
@abstractmethod
|
211 |
+
def _forward_step(self, batch):
|
212 |
+
r"""One forward step of the neural network. This abstract method is trying to
|
213 |
+
unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
|
214 |
+
However, for special case that using different forward step pattern for
|
215 |
+
training and validating, you could just override this method with ``pass`` and
|
216 |
+
implement ``_train_step`` and ``_valid_step`` separately.
|
217 |
+
"""
|
218 |
+
pass
|
219 |
+
|
220 |
+
@abstractmethod
|
221 |
+
def _save_auxiliary_states(self):
|
222 |
+
r"""To save some auxiliary states when saving model's ckpt"""
|
223 |
+
pass
|
224 |
+
|
225 |
+
### Abstract methods end ###
|
226 |
+
|
227 |
+
### THIS IS MAIN ENTRY ###
|
228 |
+
def train_loop(self):
|
229 |
+
r"""Training loop. The public entry of training process."""
|
230 |
+
# Wait everyone to prepare before we move on
|
231 |
+
self.accelerator.wait_for_everyone()
|
232 |
+
# dump config file
|
233 |
+
if self.accelerator.is_main_process:
|
234 |
+
self.__dump_cfg(self.config_save_path)
|
235 |
+
self.model.train()
|
236 |
+
self.optimizer.zero_grad()
|
237 |
+
# Wait to ensure good to go
|
238 |
+
self.accelerator.wait_for_everyone()
|
239 |
+
while self.epoch < self.max_epoch:
|
240 |
+
self.logger.info("\n")
|
241 |
+
self.logger.info("-" * 32)
|
242 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
243 |
+
|
244 |
+
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
|
245 |
+
### It's inconvenient for the model with multiple losses
|
246 |
+
# Do training & validating epoch
|
247 |
+
train_loss = self._train_epoch()
|
248 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
|
249 |
+
valid_loss = self._valid_epoch()
|
250 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
|
251 |
+
self.accelerator.log(
|
252 |
+
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
|
253 |
+
step=self.epoch,
|
254 |
+
)
|
255 |
+
|
256 |
+
self.accelerator.wait_for_everyone()
|
257 |
+
# TODO: what is scheduler?
|
258 |
+
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
|
259 |
+
|
260 |
+
# Check if hit save_checkpoint_stride and run_eval
|
261 |
+
run_eval = False
|
262 |
+
if self.accelerator.is_main_process:
|
263 |
+
save_checkpoint = False
|
264 |
+
hit_dix = []
|
265 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
266 |
+
if self.epoch % num == 0:
|
267 |
+
save_checkpoint = True
|
268 |
+
hit_dix.append(i)
|
269 |
+
run_eval |= self.run_eval[i]
|
270 |
+
|
271 |
+
self.accelerator.wait_for_everyone()
|
272 |
+
if self.accelerator.is_main_process and save_checkpoint:
|
273 |
+
path = os.path.join(
|
274 |
+
self.checkpoint_dir,
|
275 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
276 |
+
self.epoch, self.step, train_loss
|
277 |
+
),
|
278 |
+
)
|
279 |
+
self.tmp_checkpoint_save_path = path
|
280 |
+
self.accelerator.save_state(path)
|
281 |
+
print(f"save checkpoint in {path}")
|
282 |
+
json.dump(
|
283 |
+
self.checkpoints_path,
|
284 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
285 |
+
ensure_ascii=False,
|
286 |
+
indent=4,
|
287 |
+
)
|
288 |
+
self._save_auxiliary_states()
|
289 |
+
|
290 |
+
# Remove old checkpoints
|
291 |
+
to_remove = []
|
292 |
+
for idx in hit_dix:
|
293 |
+
self.checkpoints_path[idx].append(path)
|
294 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
295 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
296 |
+
|
297 |
+
# Search conflicts
|
298 |
+
total = set()
|
299 |
+
for i in self.checkpoints_path:
|
300 |
+
total |= set(i)
|
301 |
+
do_remove = set()
|
302 |
+
for idx, path in to_remove[::-1]:
|
303 |
+
if path in total:
|
304 |
+
self.checkpoints_path[idx].insert(0, path)
|
305 |
+
else:
|
306 |
+
do_remove.add(path)
|
307 |
+
|
308 |
+
# Remove old checkpoints
|
309 |
+
for path in do_remove:
|
310 |
+
shutil.rmtree(path, ignore_errors=True)
|
311 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
312 |
+
|
313 |
+
self.accelerator.wait_for_everyone()
|
314 |
+
if run_eval:
|
315 |
+
# TODO: run evaluation
|
316 |
+
pass
|
317 |
+
|
318 |
+
# Update info for each epoch
|
319 |
+
self.epoch += 1
|
320 |
+
|
321 |
+
# Finish training and save final checkpoint
|
322 |
+
self.accelerator.wait_for_everyone()
|
323 |
+
if self.accelerator.is_main_process:
|
324 |
+
self.accelerator.save_state(
|
325 |
+
os.path.join(
|
326 |
+
self.checkpoint_dir,
|
327 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
328 |
+
self.epoch, self.step, valid_loss
|
329 |
+
),
|
330 |
+
)
|
331 |
+
)
|
332 |
+
self._save_auxiliary_states()
|
333 |
+
|
334 |
+
self.accelerator.end_training()
|
335 |
+
|
336 |
+
### Following are methods that can be used directly in child classes ###
|
337 |
+
def _train_epoch(self):
|
338 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
339 |
+
one epoch. See ``train_loop`` for usage.
|
340 |
+
"""
|
341 |
+
self.model.train()
|
342 |
+
epoch_sum_loss: float = 0.0
|
343 |
+
epoch_step: int = 0
|
344 |
+
for batch in tqdm(
|
345 |
+
self.train_dataloader,
|
346 |
+
desc=f"Training Epoch {self.epoch}",
|
347 |
+
unit="batch",
|
348 |
+
colour="GREEN",
|
349 |
+
leave=False,
|
350 |
+
dynamic_ncols=True,
|
351 |
+
smoothing=0.04,
|
352 |
+
disable=not self.accelerator.is_main_process,
|
353 |
+
):
|
354 |
+
# Do training step and BP
|
355 |
+
with self.accelerator.accumulate(self.model):
|
356 |
+
loss = self._train_step(batch)
|
357 |
+
self.accelerator.backward(loss)
|
358 |
+
self.optimizer.step()
|
359 |
+
self.optimizer.zero_grad()
|
360 |
+
self.batch_count += 1
|
361 |
+
|
362 |
+
# Update info for each step
|
363 |
+
# TODO: step means BP counts or batch counts?
|
364 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
365 |
+
epoch_sum_loss += loss
|
366 |
+
self.accelerator.log(
|
367 |
+
{
|
368 |
+
"Step/Train Loss": loss,
|
369 |
+
"Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
|
370 |
+
},
|
371 |
+
step=self.step,
|
372 |
+
)
|
373 |
+
self.step += 1
|
374 |
+
epoch_step += 1
|
375 |
+
|
376 |
+
self.accelerator.wait_for_everyone()
|
377 |
+
return (
|
378 |
+
epoch_sum_loss
|
379 |
+
/ len(self.train_dataloader)
|
380 |
+
* self.cfg.train.gradient_accumulation_step
|
381 |
+
)
|
382 |
+
|
383 |
+
@torch.inference_mode()
|
384 |
+
def _valid_epoch(self):
|
385 |
+
r"""Testing epoch. Should return average loss of a batch (sample) over
|
386 |
+
one epoch. See ``train_loop`` for usage.
|
387 |
+
"""
|
388 |
+
self.model.eval()
|
389 |
+
epoch_sum_loss = 0.0
|
390 |
+
for batch in tqdm(
|
391 |
+
self.valid_dataloader,
|
392 |
+
desc=f"Validating Epoch {self.epoch}",
|
393 |
+
unit="batch",
|
394 |
+
colour="GREEN",
|
395 |
+
leave=False,
|
396 |
+
dynamic_ncols=True,
|
397 |
+
smoothing=0.04,
|
398 |
+
disable=not self.accelerator.is_main_process,
|
399 |
+
):
|
400 |
+
batch_loss = self._valid_step(batch)
|
401 |
+
epoch_sum_loss += batch_loss.item()
|
402 |
+
|
403 |
+
self.accelerator.wait_for_everyone()
|
404 |
+
return epoch_sum_loss / len(self.valid_dataloader)
|
405 |
+
|
406 |
+
def _train_step(self, batch):
|
407 |
+
r"""Training forward step. Should return average loss of a sample over
|
408 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
409 |
+
See ``_train_epoch`` for usage.
|
410 |
+
"""
|
411 |
+
return self._forward_step(batch)
|
412 |
+
|
413 |
+
@torch.inference_mode()
|
414 |
+
def _valid_step(self, batch):
|
415 |
+
r"""Testing forward step. Should return average loss of a sample over
|
416 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
417 |
+
See ``_test_epoch`` for usage.
|
418 |
+
"""
|
419 |
+
return self._forward_step(batch)
|
420 |
+
|
421 |
+
def __load_model(
|
422 |
+
self,
|
423 |
+
checkpoint_dir: str = None,
|
424 |
+
checkpoint_path: str = None,
|
425 |
+
resume_type: str = "",
|
426 |
+
):
|
427 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
428 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
429 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
430 |
+
method after** ``accelerator.prepare()``.
|
431 |
+
"""
|
432 |
+
if checkpoint_path is None:
|
433 |
+
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
|
434 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
435 |
+
checkpoint_path = ls[0]
|
436 |
+
self.logger.info("Resume from {}...".format(checkpoint_path))
|
437 |
+
|
438 |
+
if resume_type in ["resume", ""]:
|
439 |
+
# Load all the things, including model weights, optimizer, scheduler, and random states.
|
440 |
+
self.accelerator.load_state(input_dir=checkpoint_path)
|
441 |
+
|
442 |
+
# set epoch and step
|
443 |
+
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
|
444 |
+
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
|
445 |
+
|
446 |
+
elif resume_type == "finetune":
|
447 |
+
# Load only the model weights
|
448 |
+
accelerate.load_checkpoint_and_dispatch(
|
449 |
+
self.accelerator.unwrap_model(self.model),
|
450 |
+
os.path.join(checkpoint_path, "pytorch_model.bin"),
|
451 |
+
)
|
452 |
+
self.logger.info("Load model weights for finetune...")
|
453 |
+
|
454 |
+
else:
|
455 |
+
raise ValueError("Resume_type must be `resume` or `finetune`.")
|
456 |
+
|
457 |
+
return checkpoint_path
|
458 |
+
|
459 |
+
# TODO: LEGACY CODE
|
460 |
+
def _build_dataloader(self):
|
461 |
+
Dataset, Collator = self._build_dataset()
|
462 |
+
|
463 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
464 |
+
datasets_list = []
|
465 |
+
for dataset in self.cfg.dataset:
|
466 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
467 |
+
datasets_list.append(subdataset)
|
468 |
+
train_dataset = ConcatDataset(datasets_list)
|
469 |
+
train_collate = Collator(self.cfg)
|
470 |
+
_, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
|
471 |
+
self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
|
472 |
+
self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
|
473 |
+
# TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
|
474 |
+
train_loader = DataLoader(
|
475 |
+
train_dataset,
|
476 |
+
collate_fn=train_collate,
|
477 |
+
batch_sampler=batch_sampler,
|
478 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
479 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
480 |
+
)
|
481 |
+
|
482 |
+
# Build valid dataloader
|
483 |
+
datasets_list = []
|
484 |
+
for dataset in self.cfg.dataset:
|
485 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
486 |
+
datasets_list.append(subdataset)
|
487 |
+
valid_dataset = ConcatDataset(datasets_list)
|
488 |
+
valid_collate = Collator(self.cfg)
|
489 |
+
_, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
|
490 |
+
self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
|
491 |
+
self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
|
492 |
+
valid_loader = DataLoader(
|
493 |
+
valid_dataset,
|
494 |
+
collate_fn=valid_collate,
|
495 |
+
batch_sampler=batch_sampler,
|
496 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
497 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
498 |
+
)
|
499 |
+
return train_loader, valid_loader
|
500 |
+
|
501 |
+
@staticmethod
|
502 |
+
def _set_random_seed(seed):
|
503 |
+
r"""Set random seed for all possible random modules."""
|
504 |
+
random.seed(seed)
|
505 |
+
np.random.seed(seed)
|
506 |
+
torch.random.manual_seed(seed)
|
507 |
+
|
508 |
+
def _check_nan(self, loss, y_pred, y_gt):
|
509 |
+
if torch.any(torch.isnan(loss)):
|
510 |
+
self.logger.fatal("Fatal Error: Training is down since loss has Nan!")
|
511 |
+
self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
|
512 |
+
if torch.any(torch.isnan(y_pred)):
|
513 |
+
self.logger.error(
|
514 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
self.logger.debug(
|
518 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
519 |
+
)
|
520 |
+
if torch.any(torch.isnan(y_gt)):
|
521 |
+
self.logger.error(
|
522 |
+
f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
self.logger.debug(
|
526 |
+
f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
527 |
+
)
|
528 |
+
if torch.any(torch.isnan(y_pred)):
|
529 |
+
self.logger.error(f"y_pred: {y_pred}", in_order=True)
|
530 |
+
else:
|
531 |
+
self.logger.debug(f"y_pred: {y_pred}", in_order=True)
|
532 |
+
if torch.any(torch.isnan(y_gt)):
|
533 |
+
self.logger.error(f"y_gt: {y_gt}", in_order=True)
|
534 |
+
else:
|
535 |
+
self.logger.debug(f"y_gt: {y_gt}", in_order=True)
|
536 |
+
|
537 |
+
# TODO: still OK to save tracking?
|
538 |
+
self.accelerator.end_training()
|
539 |
+
raise RuntimeError("Loss has Nan! See log for more info.")
|
540 |
+
|
541 |
+
### Protected methods end ###
|
542 |
+
|
543 |
+
## Following are private methods ##
|
544 |
+
## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
|
545 |
+
def __build_optimizer(self):
|
546 |
+
r"""Build optimizer for model."""
|
547 |
+
# Make case-insensitive matching
|
548 |
+
if self.cfg.train.optimizer.lower() == "adadelta":
|
549 |
+
optimizer = torch.optim.Adadelta(
|
550 |
+
self.model.parameters(), **self.cfg.train.adadelta
|
551 |
+
)
|
552 |
+
self.logger.info("Using Adadelta optimizer.")
|
553 |
+
elif self.cfg.train.optimizer.lower() == "adagrad":
|
554 |
+
optimizer = torch.optim.Adagrad(
|
555 |
+
self.model.parameters(), **self.cfg.train.adagrad
|
556 |
+
)
|
557 |
+
self.logger.info("Using Adagrad optimizer.")
|
558 |
+
elif self.cfg.train.optimizer.lower() == "adam":
|
559 |
+
optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
|
560 |
+
self.logger.info("Using Adam optimizer.")
|
561 |
+
elif self.cfg.train.optimizer.lower() == "adamw":
|
562 |
+
optimizer = torch.optim.AdamW(
|
563 |
+
self.model.parameters(), **self.cfg.train.adamw
|
564 |
+
)
|
565 |
+
elif self.cfg.train.optimizer.lower() == "sparseadam":
|
566 |
+
optimizer = torch.optim.SparseAdam(
|
567 |
+
self.model.parameters(), **self.cfg.train.sparseadam
|
568 |
+
)
|
569 |
+
elif self.cfg.train.optimizer.lower() == "adamax":
|
570 |
+
optimizer = torch.optim.Adamax(
|
571 |
+
self.model.parameters(), **self.cfg.train.adamax
|
572 |
+
)
|
573 |
+
elif self.cfg.train.optimizer.lower() == "asgd":
|
574 |
+
optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
|
575 |
+
elif self.cfg.train.optimizer.lower() == "lbfgs":
|
576 |
+
optimizer = torch.optim.LBFGS(
|
577 |
+
self.model.parameters(), **self.cfg.train.lbfgs
|
578 |
+
)
|
579 |
+
elif self.cfg.train.optimizer.lower() == "nadam":
|
580 |
+
optimizer = torch.optim.NAdam(
|
581 |
+
self.model.parameters(), **self.cfg.train.nadam
|
582 |
+
)
|
583 |
+
elif self.cfg.train.optimizer.lower() == "radam":
|
584 |
+
optimizer = torch.optim.RAdam(
|
585 |
+
self.model.parameters(), **self.cfg.train.radam
|
586 |
+
)
|
587 |
+
elif self.cfg.train.optimizer.lower() == "rmsprop":
|
588 |
+
optimizer = torch.optim.RMSprop(
|
589 |
+
self.model.parameters(), **self.cfg.train.rmsprop
|
590 |
+
)
|
591 |
+
elif self.cfg.train.optimizer.lower() == "rprop":
|
592 |
+
optimizer = torch.optim.Rprop(
|
593 |
+
self.model.parameters(), **self.cfg.train.rprop
|
594 |
+
)
|
595 |
+
elif self.cfg.train.optimizer.lower() == "sgd":
|
596 |
+
optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
|
597 |
+
else:
|
598 |
+
raise NotImplementedError(
|
599 |
+
f"Optimizer {self.cfg.train.optimizer} not supported yet!"
|
600 |
+
)
|
601 |
+
return optimizer
|
602 |
+
|
603 |
+
def __build_scheduler(self):
|
604 |
+
r"""Build scheduler for optimizer."""
|
605 |
+
# Make case-insensitive matching
|
606 |
+
if self.cfg.train.scheduler.lower() == "lambdalr":
|
607 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
608 |
+
self.optimizer, **self.cfg.train.lambdalr
|
609 |
+
)
|
610 |
+
elif self.cfg.train.scheduler.lower() == "multiplicativelr":
|
611 |
+
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
|
612 |
+
self.optimizer, **self.cfg.train.multiplicativelr
|
613 |
+
)
|
614 |
+
elif self.cfg.train.scheduler.lower() == "steplr":
|
615 |
+
scheduler = torch.optim.lr_scheduler.StepLR(
|
616 |
+
self.optimizer, **self.cfg.train.steplr
|
617 |
+
)
|
618 |
+
elif self.cfg.train.scheduler.lower() == "multisteplr":
|
619 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
620 |
+
self.optimizer, **self.cfg.train.multisteplr
|
621 |
+
)
|
622 |
+
elif self.cfg.train.scheduler.lower() == "constantlr":
|
623 |
+
scheduler = torch.optim.lr_scheduler.ConstantLR(
|
624 |
+
self.optimizer, **self.cfg.train.constantlr
|
625 |
+
)
|
626 |
+
elif self.cfg.train.scheduler.lower() == "linearlr":
|
627 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(
|
628 |
+
self.optimizer, **self.cfg.train.linearlr
|
629 |
+
)
|
630 |
+
elif self.cfg.train.scheduler.lower() == "exponentiallr":
|
631 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(
|
632 |
+
self.optimizer, **self.cfg.train.exponentiallr
|
633 |
+
)
|
634 |
+
elif self.cfg.train.scheduler.lower() == "polynomiallr":
|
635 |
+
scheduler = torch.optim.lr_scheduler.PolynomialLR(
|
636 |
+
self.optimizer, **self.cfg.train.polynomiallr
|
637 |
+
)
|
638 |
+
elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
|
639 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
640 |
+
self.optimizer, **self.cfg.train.cosineannealinglr
|
641 |
+
)
|
642 |
+
elif self.cfg.train.scheduler.lower() == "sequentiallr":
|
643 |
+
scheduler = torch.optim.lr_scheduler.SequentialLR(
|
644 |
+
self.optimizer, **self.cfg.train.sequentiallr
|
645 |
+
)
|
646 |
+
elif self.cfg.train.scheduler.lower() == "reducelronplateau":
|
647 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
648 |
+
self.optimizer, **self.cfg.train.reducelronplateau
|
649 |
+
)
|
650 |
+
elif self.cfg.train.scheduler.lower() == "cycliclr":
|
651 |
+
scheduler = torch.optim.lr_scheduler.CyclicLR(
|
652 |
+
self.optimizer, **self.cfg.train.cycliclr
|
653 |
+
)
|
654 |
+
elif self.cfg.train.scheduler.lower() == "onecyclelr":
|
655 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
656 |
+
self.optimizer, **self.cfg.train.onecyclelr
|
657 |
+
)
|
658 |
+
elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
|
659 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
660 |
+
self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
|
661 |
+
)
|
662 |
+
elif self.cfg.train.scheduler.lower() == "noamlr":
|
663 |
+
scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
|
664 |
+
else:
|
665 |
+
raise NotImplementedError(
|
666 |
+
f"Scheduler {self.cfg.train.scheduler} not supported yet!"
|
667 |
+
)
|
668 |
+
return scheduler
|
669 |
+
|
670 |
+
def _init_accelerator(self):
|
671 |
+
self.exp_dir = os.path.join(
|
672 |
+
os.path.abspath(self.cfg.log_dir), self.args.exp_name
|
673 |
+
)
|
674 |
+
project_config = ProjectConfiguration(
|
675 |
+
project_dir=self.exp_dir,
|
676 |
+
logging_dir=os.path.join(self.exp_dir, "log"),
|
677 |
+
)
|
678 |
+
self.accelerator = accelerate.Accelerator(
|
679 |
+
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
|
680 |
+
log_with=self.cfg.train.tracker,
|
681 |
+
project_config=project_config,
|
682 |
+
)
|
683 |
+
if self.accelerator.is_main_process:
|
684 |
+
os.makedirs(project_config.project_dir, exist_ok=True)
|
685 |
+
os.makedirs(project_config.logging_dir, exist_ok=True)
|
686 |
+
with self.accelerator.main_process_first():
|
687 |
+
self.accelerator.init_trackers(self.args.exp_name)
|
688 |
+
|
689 |
+
def __check_basic_configs(self):
|
690 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
691 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
692 |
+
self.logger.error(
|
693 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
694 |
+
)
|
695 |
+
self.accelerator.end_training()
|
696 |
+
raise ValueError(
|
697 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
698 |
+
)
|
699 |
+
# TODO: check other values
|
700 |
+
|
701 |
+
@staticmethod
|
702 |
+
def __count_parameters(model):
|
703 |
+
model_param = 0.0
|
704 |
+
if isinstance(model, dict):
|
705 |
+
for key, value in model.items():
|
706 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
707 |
+
else:
|
708 |
+
model_param = sum(p.numel() for p in model.parameters())
|
709 |
+
return model_param
|
710 |
+
|
711 |
+
def __dump_cfg(self, path):
|
712 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
713 |
+
json5.dump(
|
714 |
+
self.cfg,
|
715 |
+
open(path, "w"),
|
716 |
+
indent=4,
|
717 |
+
sort_keys=True,
|
718 |
+
ensure_ascii=False,
|
719 |
+
quote_keys=True,
|
720 |
+
)
|
721 |
+
|
722 |
+
### Private methods end ###
|
models/svc/__init__.py
ADDED
File without changes
|
models/svc/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .svc_inference import SVCInference
|
7 |
+
from .svc_trainer import SVCTrainer
|
models/svc/base/svc_dataset.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
from torch.nn.utils.rnn import pad_sequence
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from utils.data_utils import *
|
13 |
+
from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
|
14 |
+
from processors.content_extractor import (
|
15 |
+
ContentvecExtractor,
|
16 |
+
WhisperExtractor,
|
17 |
+
WenetExtractor,
|
18 |
+
)
|
19 |
+
from models.base.base_dataset import (
|
20 |
+
BaseCollator,
|
21 |
+
BaseDataset,
|
22 |
+
)
|
23 |
+
from models.base.new_dataset import BaseTestDataset
|
24 |
+
|
25 |
+
EPS = 1.0e-12
|
26 |
+
|
27 |
+
|
28 |
+
class SVCDataset(BaseDataset):
|
29 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
30 |
+
BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
|
31 |
+
|
32 |
+
cfg = self.cfg
|
33 |
+
|
34 |
+
if cfg.model.condition_encoder.use_whisper:
|
35 |
+
self.whisper_aligner = WhisperExtractor(self.cfg)
|
36 |
+
self.utt2whisper_path = load_content_feature_path(
|
37 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
|
38 |
+
)
|
39 |
+
|
40 |
+
if cfg.model.condition_encoder.use_contentvec:
|
41 |
+
self.contentvec_aligner = ContentvecExtractor(self.cfg)
|
42 |
+
self.utt2contentVec_path = load_content_feature_path(
|
43 |
+
self.metadata,
|
44 |
+
cfg.preprocess.processed_dir,
|
45 |
+
cfg.preprocess.contentvec_dir,
|
46 |
+
)
|
47 |
+
|
48 |
+
if cfg.model.condition_encoder.use_mert:
|
49 |
+
self.utt2mert_path = load_content_feature_path(
|
50 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
|
51 |
+
)
|
52 |
+
if cfg.model.condition_encoder.use_wenet:
|
53 |
+
self.wenet_aligner = WenetExtractor(self.cfg)
|
54 |
+
self.utt2wenet_path = load_content_feature_path(
|
55 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
|
56 |
+
)
|
57 |
+
|
58 |
+
def __getitem__(self, index):
|
59 |
+
single_feature = BaseDataset.__getitem__(self, index)
|
60 |
+
|
61 |
+
utt_info = self.metadata[index]
|
62 |
+
dataset = utt_info["Dataset"]
|
63 |
+
uid = utt_info["Uid"]
|
64 |
+
utt = "{}_{}".format(dataset, uid)
|
65 |
+
|
66 |
+
if self.cfg.model.condition_encoder.use_whisper:
|
67 |
+
assert "target_len" in single_feature.keys()
|
68 |
+
aligned_whisper_feat = self.whisper_aligner.offline_align(
|
69 |
+
np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
|
70 |
+
)
|
71 |
+
single_feature["whisper_feat"] = aligned_whisper_feat
|
72 |
+
|
73 |
+
if self.cfg.model.condition_encoder.use_contentvec:
|
74 |
+
assert "target_len" in single_feature.keys()
|
75 |
+
aligned_contentvec = self.contentvec_aligner.offline_align(
|
76 |
+
np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
|
77 |
+
)
|
78 |
+
single_feature["contentvec_feat"] = aligned_contentvec
|
79 |
+
|
80 |
+
if self.cfg.model.condition_encoder.use_mert:
|
81 |
+
assert "target_len" in single_feature.keys()
|
82 |
+
aligned_mert_feat = align_content_feature_length(
|
83 |
+
np.load(self.utt2mert_path[utt]),
|
84 |
+
single_feature["target_len"],
|
85 |
+
source_hop=self.cfg.preprocess.mert_hop_size,
|
86 |
+
)
|
87 |
+
single_feature["mert_feat"] = aligned_mert_feat
|
88 |
+
|
89 |
+
if self.cfg.model.condition_encoder.use_wenet:
|
90 |
+
assert "target_len" in single_feature.keys()
|
91 |
+
aligned_wenet_feat = self.wenet_aligner.offline_align(
|
92 |
+
np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
|
93 |
+
)
|
94 |
+
single_feature["wenet_feat"] = aligned_wenet_feat
|
95 |
+
|
96 |
+
# print(single_feature.keys())
|
97 |
+
# for k, v in single_feature.items():
|
98 |
+
# if type(v) in [torch.Tensor, np.ndarray]:
|
99 |
+
# print(k, v.shape)
|
100 |
+
# else:
|
101 |
+
# print(k, v)
|
102 |
+
# exit()
|
103 |
+
|
104 |
+
return self.clip_if_too_long(single_feature)
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.metadata)
|
108 |
+
|
109 |
+
def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
|
110 |
+
"""
|
111 |
+
ending_ts: to avoid invalid whisper features for over 30s audios
|
112 |
+
2812 = 30 * 24000 // 256
|
113 |
+
"""
|
114 |
+
ts = max(feature_seq_len - max_seq_len, 0)
|
115 |
+
ts = min(ts, ending_ts - max_seq_len)
|
116 |
+
|
117 |
+
start = random.randint(0, ts)
|
118 |
+
end = start + max_seq_len
|
119 |
+
return start, end
|
120 |
+
|
121 |
+
def clip_if_too_long(self, sample, max_seq_len=512):
|
122 |
+
"""
|
123 |
+
sample :
|
124 |
+
{
|
125 |
+
'spk_id': (1,),
|
126 |
+
'target_len': int
|
127 |
+
'mel': (seq_len, dim),
|
128 |
+
'frame_pitch': (seq_len,)
|
129 |
+
'frame_energy': (seq_len,)
|
130 |
+
'content_vector_feat': (seq_len, dim)
|
131 |
+
}
|
132 |
+
"""
|
133 |
+
|
134 |
+
if sample["target_len"] <= max_seq_len:
|
135 |
+
return sample
|
136 |
+
|
137 |
+
start, end = self.random_select(sample["target_len"], max_seq_len)
|
138 |
+
sample["target_len"] = end - start
|
139 |
+
|
140 |
+
for k in sample.keys():
|
141 |
+
if k == "audio":
|
142 |
+
# audio should be clipped in hop_size scale
|
143 |
+
sample[k] = sample[k][
|
144 |
+
start
|
145 |
+
* self.cfg.preprocess.hop_size : end
|
146 |
+
* self.cfg.preprocess.hop_size
|
147 |
+
]
|
148 |
+
elif k == "audio_len":
|
149 |
+
sample[k] = (end - start) * self.cfg.preprocess.hop_size
|
150 |
+
elif k not in ["spk_id", "target_len"]:
|
151 |
+
sample[k] = sample[k][start:end]
|
152 |
+
|
153 |
+
return sample
|
154 |
+
|
155 |
+
|
156 |
+
class SVCCollator(BaseCollator):
|
157 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
158 |
+
|
159 |
+
def __init__(self, cfg):
|
160 |
+
BaseCollator.__init__(self, cfg)
|
161 |
+
|
162 |
+
def __call__(self, batch):
|
163 |
+
parsed_batch_features = BaseCollator.__call__(self, batch)
|
164 |
+
return parsed_batch_features
|
165 |
+
|
166 |
+
|
167 |
+
class SVCTestDataset(BaseTestDataset):
|
168 |
+
def __init__(self, args, cfg, infer_type):
|
169 |
+
BaseTestDataset.__init__(self, args, cfg, infer_type)
|
170 |
+
self.metadata = self.get_metadata()
|
171 |
+
|
172 |
+
target_singer = args.target_singer
|
173 |
+
self.cfg = cfg
|
174 |
+
self.trans_key = args.trans_key
|
175 |
+
assert type(target_singer) == str
|
176 |
+
|
177 |
+
self.target_singer = target_singer.split("_")[-1]
|
178 |
+
self.target_dataset = target_singer.replace(
|
179 |
+
"_{}".format(self.target_singer), ""
|
180 |
+
)
|
181 |
+
if cfg.preprocess.mel_min_max_norm:
|
182 |
+
self.target_mel_extrema = load_mel_extrema(
|
183 |
+
cfg.preprocess, self.target_dataset
|
184 |
+
)
|
185 |
+
self.target_mel_extrema = torch.as_tensor(
|
186 |
+
self.target_mel_extrema[0]
|
187 |
+
), torch.as_tensor(self.target_mel_extrema[1])
|
188 |
+
|
189 |
+
######### Load source acoustic features #########
|
190 |
+
if cfg.preprocess.use_spkid:
|
191 |
+
spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)
|
192 |
+
# utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk)
|
193 |
+
|
194 |
+
with open(spk2id_path, "r") as f:
|
195 |
+
self.spk2id = json.load(f)
|
196 |
+
# print("self.spk2id", self.spk2id)
|
197 |
+
|
198 |
+
if cfg.preprocess.use_uv:
|
199 |
+
self.utt2uv_path = {
|
200 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
201 |
+
cfg.preprocess.processed_dir,
|
202 |
+
utt_info["Dataset"],
|
203 |
+
cfg.preprocess.uv_dir,
|
204 |
+
utt_info["Uid"] + ".npy",
|
205 |
+
)
|
206 |
+
for utt_info in self.metadata
|
207 |
+
}
|
208 |
+
|
209 |
+
if cfg.preprocess.use_frame_pitch:
|
210 |
+
self.utt2frame_pitch_path = {
|
211 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
212 |
+
cfg.preprocess.processed_dir,
|
213 |
+
utt_info["Dataset"],
|
214 |
+
cfg.preprocess.pitch_dir,
|
215 |
+
utt_info["Uid"] + ".npy",
|
216 |
+
)
|
217 |
+
for utt_info in self.metadata
|
218 |
+
}
|
219 |
+
|
220 |
+
# Target F0 median
|
221 |
+
target_f0_statistics_path = os.path.join(
|
222 |
+
cfg.preprocess.processed_dir,
|
223 |
+
self.target_dataset,
|
224 |
+
cfg.preprocess.pitch_dir,
|
225 |
+
"statistics.json",
|
226 |
+
)
|
227 |
+
self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[
|
228 |
+
f"{self.target_dataset}_{self.target_singer}"
|
229 |
+
]["voiced_positions"]["median"]
|
230 |
+
|
231 |
+
# Source F0 median (if infer from file)
|
232 |
+
if infer_type == "from_file":
|
233 |
+
source_audio_name = cfg.inference.source_audio_name
|
234 |
+
source_f0_statistics_path = os.path.join(
|
235 |
+
cfg.preprocess.processed_dir,
|
236 |
+
source_audio_name,
|
237 |
+
cfg.preprocess.pitch_dir,
|
238 |
+
"statistics.json",
|
239 |
+
)
|
240 |
+
self.source_pitch_median = json.load(
|
241 |
+
open(source_f0_statistics_path, "r")
|
242 |
+
)[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][
|
243 |
+
"median"
|
244 |
+
]
|
245 |
+
else:
|
246 |
+
self.source_pitch_median = None
|
247 |
+
|
248 |
+
if cfg.preprocess.use_frame_energy:
|
249 |
+
self.utt2frame_energy_path = {
|
250 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
251 |
+
cfg.preprocess.processed_dir,
|
252 |
+
utt_info["Dataset"],
|
253 |
+
cfg.preprocess.energy_dir,
|
254 |
+
utt_info["Uid"] + ".npy",
|
255 |
+
)
|
256 |
+
for utt_info in self.metadata
|
257 |
+
}
|
258 |
+
|
259 |
+
if cfg.preprocess.use_mel:
|
260 |
+
self.utt2mel_path = {
|
261 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
262 |
+
cfg.preprocess.processed_dir,
|
263 |
+
utt_info["Dataset"],
|
264 |
+
cfg.preprocess.mel_dir,
|
265 |
+
utt_info["Uid"] + ".npy",
|
266 |
+
)
|
267 |
+
for utt_info in self.metadata
|
268 |
+
}
|
269 |
+
|
270 |
+
######### Load source content features' path #########
|
271 |
+
if cfg.model.condition_encoder.use_whisper:
|
272 |
+
self.whisper_aligner = WhisperExtractor(cfg)
|
273 |
+
self.utt2whisper_path = load_content_feature_path(
|
274 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
|
275 |
+
)
|
276 |
+
|
277 |
+
if cfg.model.condition_encoder.use_contentvec:
|
278 |
+
self.contentvec_aligner = ContentvecExtractor(cfg)
|
279 |
+
self.utt2contentVec_path = load_content_feature_path(
|
280 |
+
self.metadata,
|
281 |
+
cfg.preprocess.processed_dir,
|
282 |
+
cfg.preprocess.contentvec_dir,
|
283 |
+
)
|
284 |
+
|
285 |
+
if cfg.model.condition_encoder.use_mert:
|
286 |
+
self.utt2mert_path = load_content_feature_path(
|
287 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
|
288 |
+
)
|
289 |
+
if cfg.model.condition_encoder.use_wenet:
|
290 |
+
self.wenet_aligner = WenetExtractor(cfg)
|
291 |
+
self.utt2wenet_path = load_content_feature_path(
|
292 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
|
293 |
+
)
|
294 |
+
|
295 |
+
def __getitem__(self, index):
|
296 |
+
single_feature = {}
|
297 |
+
|
298 |
+
utt_info = self.metadata[index]
|
299 |
+
dataset = utt_info["Dataset"]
|
300 |
+
uid = utt_info["Uid"]
|
301 |
+
utt = "{}_{}".format(dataset, uid)
|
302 |
+
|
303 |
+
source_dataset = self.metadata[index]["Dataset"]
|
304 |
+
|
305 |
+
if self.cfg.preprocess.use_spkid:
|
306 |
+
single_feature["spk_id"] = np.array(
|
307 |
+
[self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
|
308 |
+
dtype=np.int32,
|
309 |
+
)
|
310 |
+
|
311 |
+
######### Get Acoustic Features Item #########
|
312 |
+
if self.cfg.preprocess.use_mel:
|
313 |
+
mel = np.load(self.utt2mel_path[utt])
|
314 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
315 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
316 |
+
# mel norm
|
317 |
+
mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess)
|
318 |
+
|
319 |
+
if "target_len" not in single_feature.keys():
|
320 |
+
single_feature["target_len"] = mel.shape[1]
|
321 |
+
single_feature["mel"] = mel.T # [T, n_mels]
|
322 |
+
|
323 |
+
if self.cfg.preprocess.use_frame_pitch:
|
324 |
+
frame_pitch_path = self.utt2frame_pitch_path[utt]
|
325 |
+
frame_pitch = np.load(frame_pitch_path)
|
326 |
+
|
327 |
+
if self.trans_key:
|
328 |
+
try:
|
329 |
+
self.trans_key = int(self.trans_key)
|
330 |
+
except:
|
331 |
+
pass
|
332 |
+
if type(self.trans_key) == int:
|
333 |
+
frame_pitch = transpose_key(frame_pitch, self.trans_key)
|
334 |
+
elif self.trans_key:
|
335 |
+
assert self.target_singer
|
336 |
+
|
337 |
+
frame_pitch = pitch_shift_to_target(
|
338 |
+
frame_pitch, self.target_pitch_median, self.source_pitch_median
|
339 |
+
)
|
340 |
+
|
341 |
+
if "target_len" not in single_feature.keys():
|
342 |
+
single_feature["target_len"] = len(frame_pitch)
|
343 |
+
aligned_frame_pitch = align_length(
|
344 |
+
frame_pitch, single_feature["target_len"]
|
345 |
+
)
|
346 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
347 |
+
|
348 |
+
if self.cfg.preprocess.use_uv:
|
349 |
+
frame_uv_path = self.utt2uv_path[utt]
|
350 |
+
frame_uv = np.load(frame_uv_path)
|
351 |
+
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
|
352 |
+
aligned_frame_uv = [
|
353 |
+
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
|
354 |
+
]
|
355 |
+
aligned_frame_uv = np.array(aligned_frame_uv)
|
356 |
+
single_feature["frame_uv"] = aligned_frame_uv
|
357 |
+
|
358 |
+
if self.cfg.preprocess.use_frame_energy:
|
359 |
+
frame_energy_path = self.utt2frame_energy_path[utt]
|
360 |
+
frame_energy = np.load(frame_energy_path)
|
361 |
+
if "target_len" not in single_feature.keys():
|
362 |
+
single_feature["target_len"] = len(frame_energy)
|
363 |
+
aligned_frame_energy = align_length(
|
364 |
+
frame_energy, single_feature["target_len"]
|
365 |
+
)
|
366 |
+
single_feature["frame_energy"] = aligned_frame_energy
|
367 |
+
|
368 |
+
######### Get Content Features Item #########
|
369 |
+
if self.cfg.model.condition_encoder.use_whisper:
|
370 |
+
assert "target_len" in single_feature.keys()
|
371 |
+
aligned_whisper_feat = self.whisper_aligner.offline_align(
|
372 |
+
np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
|
373 |
+
)
|
374 |
+
single_feature["whisper_feat"] = aligned_whisper_feat
|
375 |
+
|
376 |
+
if self.cfg.model.condition_encoder.use_contentvec:
|
377 |
+
assert "target_len" in single_feature.keys()
|
378 |
+
aligned_contentvec = self.contentvec_aligner.offline_align(
|
379 |
+
np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
|
380 |
+
)
|
381 |
+
single_feature["contentvec_feat"] = aligned_contentvec
|
382 |
+
|
383 |
+
if self.cfg.model.condition_encoder.use_mert:
|
384 |
+
assert "target_len" in single_feature.keys()
|
385 |
+
aligned_mert_feat = align_content_feature_length(
|
386 |
+
np.load(self.utt2mert_path[utt]),
|
387 |
+
single_feature["target_len"],
|
388 |
+
source_hop=self.cfg.preprocess.mert_hop_size,
|
389 |
+
)
|
390 |
+
single_feature["mert_feat"] = aligned_mert_feat
|
391 |
+
|
392 |
+
if self.cfg.model.condition_encoder.use_wenet:
|
393 |
+
assert "target_len" in single_feature.keys()
|
394 |
+
aligned_wenet_feat = self.wenet_aligner.offline_align(
|
395 |
+
np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
|
396 |
+
)
|
397 |
+
single_feature["wenet_feat"] = aligned_wenet_feat
|
398 |
+
|
399 |
+
return single_feature
|
400 |
+
|
401 |
+
def __len__(self):
|
402 |
+
return len(self.metadata)
|
403 |
+
|
404 |
+
|
405 |
+
class SVCTestCollator:
|
406 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
407 |
+
|
408 |
+
def __init__(self, cfg):
|
409 |
+
self.cfg = cfg
|
410 |
+
|
411 |
+
def __call__(self, batch):
|
412 |
+
packed_batch_features = dict()
|
413 |
+
|
414 |
+
# mel: [b, T, n_mels]
|
415 |
+
# frame_pitch, frame_energy: [1, T]
|
416 |
+
# target_len: [1]
|
417 |
+
# spk_id: [b, 1]
|
418 |
+
# mask: [b, T, 1]
|
419 |
+
|
420 |
+
for key in batch[0].keys():
|
421 |
+
if key == "target_len":
|
422 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
423 |
+
[b["target_len"] for b in batch]
|
424 |
+
)
|
425 |
+
masks = [
|
426 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
427 |
+
]
|
428 |
+
packed_batch_features["mask"] = pad_sequence(
|
429 |
+
masks, batch_first=True, padding_value=0
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
433 |
+
packed_batch_features[key] = pad_sequence(
|
434 |
+
values, batch_first=True, padding_value=0
|
435 |
+
)
|
436 |
+
|
437 |
+
return packed_batch_features
|
models/svc/base/svc_inference.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from models.base.new_inference import BaseInference
|
7 |
+
from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset
|
8 |
+
|
9 |
+
|
10 |
+
class SVCInference(BaseInference):
|
11 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
12 |
+
BaseInference.__init__(self, args, cfg, infer_type)
|
13 |
+
|
14 |
+
def _build_test_dataset(self):
|
15 |
+
return SVCTestDataset, SVCTestCollator
|
models/svc/base/svc_trainer.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from models.base.new_trainer import BaseTrainer
|
13 |
+
from models.svc.base.svc_dataset import SVCCollator, SVCDataset
|
14 |
+
|
15 |
+
|
16 |
+
class SVCTrainer(BaseTrainer):
|
17 |
+
r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
|
18 |
+
``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
|
19 |
+
class, and implement ``_build_model``, ``_forward_step``.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, args=None, cfg=None):
|
23 |
+
self.args = args
|
24 |
+
self.cfg = cfg
|
25 |
+
|
26 |
+
self._init_accelerator()
|
27 |
+
|
28 |
+
# Only for SVC tasks
|
29 |
+
with self.accelerator.main_process_first():
|
30 |
+
self.singers = self._build_singer_lut()
|
31 |
+
|
32 |
+
# Super init
|
33 |
+
BaseTrainer.__init__(self, args, cfg)
|
34 |
+
|
35 |
+
# Only for SVC tasks
|
36 |
+
self.task_type = "SVC"
|
37 |
+
self.logger.info("Task type: {}".format(self.task_type))
|
38 |
+
|
39 |
+
### Following are methods only for SVC tasks ###
|
40 |
+
# TODO: LEGACY CODE, NEED TO BE REFACTORED
|
41 |
+
def _build_dataset(self):
|
42 |
+
return SVCDataset, SVCCollator
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def _build_criterion():
|
46 |
+
criterion = nn.MSELoss(reduction="none")
|
47 |
+
return criterion
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def _compute_loss(criterion, y_pred, y_gt, loss_mask):
|
51 |
+
"""
|
52 |
+
Args:
|
53 |
+
criterion: MSELoss(reduction='none')
|
54 |
+
y_pred, y_gt: (bs, seq_len, D)
|
55 |
+
loss_mask: (bs, seq_len, 1)
|
56 |
+
Returns:
|
57 |
+
loss: Tensor of shape []
|
58 |
+
"""
|
59 |
+
|
60 |
+
# (bs, seq_len, D)
|
61 |
+
loss = criterion(y_pred, y_gt)
|
62 |
+
# expand loss_mask to (bs, seq_len, D)
|
63 |
+
loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
|
64 |
+
|
65 |
+
loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
|
66 |
+
return loss
|
67 |
+
|
68 |
+
def _save_auxiliary_states(self):
|
69 |
+
"""
|
70 |
+
To save the singer's look-up table in the checkpoint saving path
|
71 |
+
"""
|
72 |
+
with open(
|
73 |
+
os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w"
|
74 |
+
) as f:
|
75 |
+
json.dump(self.singers, f, indent=4, ensure_ascii=False)
|
76 |
+
|
77 |
+
def _build_singer_lut(self):
|
78 |
+
resumed_singer_path = None
|
79 |
+
if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
|
80 |
+
resumed_singer_path = os.path.join(
|
81 |
+
self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
|
82 |
+
)
|
83 |
+
if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
|
84 |
+
resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
85 |
+
|
86 |
+
if resumed_singer_path:
|
87 |
+
with open(resumed_singer_path, "r") as f:
|
88 |
+
singers = json.load(f)
|
89 |
+
else:
|
90 |
+
singers = dict()
|
91 |
+
|
92 |
+
for dataset in self.cfg.dataset:
|
93 |
+
singer_lut_path = os.path.join(
|
94 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
95 |
+
)
|
96 |
+
with open(singer_lut_path, "r") as singer_lut_path:
|
97 |
+
singer_lut = json.load(singer_lut_path)
|
98 |
+
for singer in singer_lut.keys():
|
99 |
+
if singer not in singers:
|
100 |
+
singers[singer] = len(singers)
|
101 |
+
|
102 |
+
with open(
|
103 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
|
104 |
+
) as singer_file:
|
105 |
+
json.dump(singers, singer_file, indent=4, ensure_ascii=False)
|
106 |
+
print(
|
107 |
+
"singers have been dumped to {}".format(
|
108 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
109 |
+
)
|
110 |
+
)
|
111 |
+
return singers
|
models/svc/comosvc/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
models/svc/comosvc/comosvc.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""Adapted from https://github.com/zhenye234/CoMoSpeech"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import copy
|
11 |
+
import numpy as np
|
12 |
+
import math
|
13 |
+
from tqdm.auto import tqdm
|
14 |
+
|
15 |
+
from utils.ssim import SSIM
|
16 |
+
|
17 |
+
from models.svc.transformer.conformer import Conformer, BaseModule
|
18 |
+
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
|
19 |
+
from models.svc.comosvc.utils import slice_segments, rand_ids_segments
|
20 |
+
|
21 |
+
|
22 |
+
class Consistency(nn.Module):
|
23 |
+
def __init__(self, cfg, distill=False):
|
24 |
+
super().__init__()
|
25 |
+
self.cfg = cfg
|
26 |
+
# self.denoise_fn = GradLogPEstimator2d(96)
|
27 |
+
self.denoise_fn = DiffusionWrapper(self.cfg)
|
28 |
+
self.cfg = cfg.model.comosvc
|
29 |
+
self.teacher = not distill
|
30 |
+
self.P_mean = self.cfg.P_mean
|
31 |
+
self.P_std = self.cfg.P_std
|
32 |
+
self.sigma_data = self.cfg.sigma_data
|
33 |
+
self.sigma_min = self.cfg.sigma_min
|
34 |
+
self.sigma_max = self.cfg.sigma_max
|
35 |
+
self.rho = self.cfg.rho
|
36 |
+
self.N = self.cfg.n_timesteps
|
37 |
+
self.ssim_loss = SSIM()
|
38 |
+
|
39 |
+
# Time step discretization
|
40 |
+
step_indices = torch.arange(self.N)
|
41 |
+
# karras boundaries formula
|
42 |
+
t_steps = (
|
43 |
+
self.sigma_min ** (1 / self.rho)
|
44 |
+
+ step_indices
|
45 |
+
/ (self.N - 1)
|
46 |
+
* (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
|
47 |
+
) ** self.rho
|
48 |
+
self.t_steps = torch.cat(
|
49 |
+
[torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
|
50 |
+
)
|
51 |
+
|
52 |
+
def init_consistency_training(self):
|
53 |
+
self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
|
54 |
+
self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
|
55 |
+
|
56 |
+
def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
|
57 |
+
"""
|
58 |
+
karras diffusion reverse process
|
59 |
+
|
60 |
+
Args:
|
61 |
+
x: noisy mel-spectrogram [B x n_mel x L]
|
62 |
+
sigma: noise level [B x 1 x 1]
|
63 |
+
cond: output of conformer encoder [B x n_mel x L]
|
64 |
+
denoise_fn: denoiser neural network e.g. DilatedCNN
|
65 |
+
mask: mask of padded frames [B x n_mel x L]
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
denoised mel-spectrogram [B x n_mel x L]
|
69 |
+
"""
|
70 |
+
sigma = sigma.reshape(-1, 1, 1)
|
71 |
+
|
72 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
73 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
74 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
75 |
+
c_noise = sigma.log() / 4
|
76 |
+
|
77 |
+
x_in = c_in * x
|
78 |
+
x_in = x_in.transpose(1, 2)
|
79 |
+
x = x.transpose(1, 2)
|
80 |
+
cond = cond.transpose(1, 2)
|
81 |
+
F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
|
82 |
+
# F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten())
|
83 |
+
D_x = c_skip * x + c_out * (F_x)
|
84 |
+
D_x = D_x.transpose(1, 2)
|
85 |
+
return D_x
|
86 |
+
|
87 |
+
def EDMLoss(self, x_start, cond, mask):
|
88 |
+
"""
|
89 |
+
compute loss for EDM model
|
90 |
+
|
91 |
+
Args:
|
92 |
+
x_start: ground truth mel-spectrogram [B x n_mel x L]
|
93 |
+
cond: output of conformer encoder [B x n_mel x L]
|
94 |
+
mask: mask of padded frames [B x n_mel x L]
|
95 |
+
"""
|
96 |
+
rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
|
97 |
+
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
|
98 |
+
weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
99 |
+
|
100 |
+
# follow Grad-TTS, start from Gaussian noise with mean cond and std I
|
101 |
+
noise = (torch.randn_like(x_start) + cond) * sigma
|
102 |
+
D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
|
103 |
+
loss = weight * ((D_yn - x_start) ** 2)
|
104 |
+
loss = torch.sum(loss * mask) / torch.sum(mask)
|
105 |
+
return loss
|
106 |
+
|
107 |
+
def round_sigma(self, sigma):
|
108 |
+
return torch.as_tensor(sigma)
|
109 |
+
|
110 |
+
def edm_sampler(
|
111 |
+
self,
|
112 |
+
latents,
|
113 |
+
cond,
|
114 |
+
nonpadding,
|
115 |
+
num_steps=50,
|
116 |
+
sigma_min=0.002,
|
117 |
+
sigma_max=80,
|
118 |
+
rho=7,
|
119 |
+
S_churn=0,
|
120 |
+
S_min=0,
|
121 |
+
S_max=float("inf"),
|
122 |
+
S_noise=1,
|
123 |
+
# S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
124 |
+
# S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
|
125 |
+
# S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
|
126 |
+
# S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
|
127 |
+
):
|
128 |
+
"""
|
129 |
+
karras diffusion sampler
|
130 |
+
|
131 |
+
Args:
|
132 |
+
latents: noisy mel-spectrogram [B x n_mel x L]
|
133 |
+
cond: output of conformer encoder [B x n_mel x L]
|
134 |
+
nonpadding: mask of padded frames [B x n_mel x L]
|
135 |
+
num_steps: number of steps for diffusion inference
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
denoised mel-spectrogram [B x n_mel x L]
|
139 |
+
"""
|
140 |
+
# Time step discretization.
|
141 |
+
step_indices = torch.arange(num_steps, device=latents.device)
|
142 |
+
|
143 |
+
num_steps = num_steps + 1
|
144 |
+
t_steps = (
|
145 |
+
sigma_max ** (1 / rho)
|
146 |
+
+ step_indices
|
147 |
+
/ (num_steps - 1)
|
148 |
+
* (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
|
149 |
+
) ** rho
|
150 |
+
t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
|
151 |
+
|
152 |
+
# Main sampling loop.
|
153 |
+
x_next = latents * t_steps[0]
|
154 |
+
# wrap in tqdm for progress bar
|
155 |
+
bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
|
156 |
+
for i, (t_cur, t_next) in bar:
|
157 |
+
x_cur = x_next
|
158 |
+
# Increase noise temporarily.
|
159 |
+
gamma = (
|
160 |
+
min(S_churn / num_steps, np.sqrt(2) - 1)
|
161 |
+
if S_min <= t_cur <= S_max
|
162 |
+
else 0
|
163 |
+
)
|
164 |
+
t_hat = self.round_sigma(t_cur + gamma * t_cur)
|
165 |
+
t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
|
166 |
+
t[:, 0, 0] = t_hat
|
167 |
+
t_hat = t
|
168 |
+
x_hat = x_cur + (
|
169 |
+
t_hat**2 - t_cur**2
|
170 |
+
).sqrt() * S_noise * torch.randn_like(x_cur)
|
171 |
+
# Euler step.
|
172 |
+
denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
|
173 |
+
d_cur = (x_hat - denoised) / t_hat
|
174 |
+
x_next = x_hat + (t_next - t_hat) * d_cur
|
175 |
+
|
176 |
+
return x_next
|
177 |
+
|
178 |
+
def CTLoss_D(self, y, cond, mask):
|
179 |
+
"""
|
180 |
+
compute loss for consistency distillation
|
181 |
+
|
182 |
+
Args:
|
183 |
+
y: ground truth mel-spectrogram [B x n_mel x L]
|
184 |
+
cond: output of conformer encoder [B x n_mel x L]
|
185 |
+
mask: mask of padded frames [B x n_mel x L]
|
186 |
+
"""
|
187 |
+
with torch.no_grad():
|
188 |
+
mu = 0.95
|
189 |
+
for p, ema_p in zip(
|
190 |
+
self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
|
191 |
+
):
|
192 |
+
ema_p.mul_(mu).add_(p, alpha=1 - mu)
|
193 |
+
|
194 |
+
n = torch.randint(1, self.N, (y.shape[0],))
|
195 |
+
z = torch.randn_like(y) + cond
|
196 |
+
|
197 |
+
tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
|
198 |
+
f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)
|
199 |
+
|
200 |
+
with torch.no_grad():
|
201 |
+
tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)
|
202 |
+
|
203 |
+
# euler step
|
204 |
+
x_hat = y + tn_1 * z
|
205 |
+
denoised = self.EDMPrecond(
|
206 |
+
x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
|
207 |
+
)
|
208 |
+
d_cur = (x_hat - denoised) / tn_1
|
209 |
+
y_tn = x_hat + (tn - tn_1) * d_cur
|
210 |
+
|
211 |
+
f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)
|
212 |
+
|
213 |
+
# loss = (f_theta - f_theta_ema.detach()) ** 2
|
214 |
+
# loss = torch.sum(loss * mask) / torch.sum(mask)
|
215 |
+
loss = self.ssim_loss(f_theta, f_theta_ema.detach())
|
216 |
+
loss = torch.sum(loss * mask) / torch.sum(mask)
|
217 |
+
|
218 |
+
return loss
|
219 |
+
|
220 |
+
def get_t_steps(self, N):
|
221 |
+
N = N + 1
|
222 |
+
step_indices = torch.arange(N) # , device=latents.device)
|
223 |
+
t_steps = (
|
224 |
+
self.sigma_min ** (1 / self.rho)
|
225 |
+
+ step_indices
|
226 |
+
/ (N - 1)
|
227 |
+
* (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
|
228 |
+
) ** self.rho
|
229 |
+
|
230 |
+
return t_steps.flip(0)
|
231 |
+
|
232 |
+
def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
|
233 |
+
"""
|
234 |
+
consistency distillation sampler
|
235 |
+
|
236 |
+
Args:
|
237 |
+
latents: noisy mel-spectrogram [B x n_mel x L]
|
238 |
+
cond: output of conformer encoder [B x n_mel x L]
|
239 |
+
nonpadding: mask of padded frames [B x n_mel x L]
|
240 |
+
t_steps: number of steps for diffusion inference
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
denoised mel-spectrogram [B x n_mel x L]
|
244 |
+
"""
|
245 |
+
# one-step
|
246 |
+
if t_steps == 1:
|
247 |
+
t_steps = [80]
|
248 |
+
# multi-step
|
249 |
+
else:
|
250 |
+
t_steps = self.get_t_steps(t_steps)
|
251 |
+
|
252 |
+
t_steps = torch.as_tensor(t_steps).to(latents.device)
|
253 |
+
latents = latents * t_steps[0]
|
254 |
+
_t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
|
255 |
+
_t[:, 0, 0] = t_steps
|
256 |
+
x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)
|
257 |
+
|
258 |
+
for t in t_steps[1:-1]:
|
259 |
+
z = torch.randn_like(x) + cond
|
260 |
+
x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
|
261 |
+
_t = torch.zeros((x.shape[0], 1, 1), device=x.device)
|
262 |
+
_t[:, 0, 0] = t
|
263 |
+
t = _t
|
264 |
+
print(t)
|
265 |
+
x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
|
266 |
+
return x
|
267 |
+
|
268 |
+
def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
|
269 |
+
"""
|
270 |
+
calculate loss or sample mel-spectrogram
|
271 |
+
|
272 |
+
Args:
|
273 |
+
x:
|
274 |
+
training: ground truth mel-spectrogram [B x n_mel x L]
|
275 |
+
inference: output of encoder [B x n_mel x L]
|
276 |
+
"""
|
277 |
+
if self.teacher: # teacher model -- karras diffusion
|
278 |
+
if not infer:
|
279 |
+
loss = self.EDMLoss(x, cond, nonpadding)
|
280 |
+
return loss
|
281 |
+
else:
|
282 |
+
shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
|
283 |
+
x = torch.randn(shape, device=x.device) + cond
|
284 |
+
x = self.edm_sampler(x, cond, nonpadding, t_steps)
|
285 |
+
|
286 |
+
return x
|
287 |
+
else: # Consistency distillation
|
288 |
+
if not infer:
|
289 |
+
loss = self.CTLoss_D(x, cond, nonpadding)
|
290 |
+
return loss
|
291 |
+
|
292 |
+
else:
|
293 |
+
shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
|
294 |
+
x = torch.randn(shape, device=x.device) + cond
|
295 |
+
x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
|
296 |
+
|
297 |
+
return x
|
298 |
+
|
299 |
+
|
300 |
+
class ComoSVC(BaseModule):
|
301 |
+
def __init__(self, cfg):
|
302 |
+
super().__init__()
|
303 |
+
self.cfg = cfg
|
304 |
+
self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
|
305 |
+
self.distill = self.cfg.model.comosvc.distill
|
306 |
+
self.encoder = Conformer(self.cfg.model.comosvc)
|
307 |
+
self.decoder = Consistency(self.cfg, distill=self.distill)
|
308 |
+
self.ssim_loss = SSIM()
|
309 |
+
|
310 |
+
@torch.no_grad()
|
311 |
+
def forward(self, x_mask, x, n_timesteps, temperature=1.0):
|
312 |
+
"""
|
313 |
+
Generates mel-spectrogram from pitch, content vector, energy. Returns:
|
314 |
+
1. encoder outputs (from conformer)
|
315 |
+
2. decoder outputs (from diffusion-based decoder)
|
316 |
+
|
317 |
+
Args:
|
318 |
+
x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
|
319 |
+
x : output of encoder framework. [B x L x d_condition]
|
320 |
+
n_timesteps : number of steps to use for reverse diffusion in decoder.
|
321 |
+
temperature : controls variance of terminal distribution.
|
322 |
+
"""
|
323 |
+
|
324 |
+
# Get encoder_outputs `mu_x`
|
325 |
+
mu_x = self.encoder(x, x_mask)
|
326 |
+
encoder_outputs = mu_x
|
327 |
+
|
328 |
+
mu_x = mu_x.transpose(1, 2)
|
329 |
+
x_mask = x_mask.transpose(1, 2)
|
330 |
+
|
331 |
+
# Generate sample by performing reverse dynamics
|
332 |
+
decoder_outputs = self.decoder(
|
333 |
+
mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
|
334 |
+
)
|
335 |
+
decoder_outputs = decoder_outputs.transpose(1, 2)
|
336 |
+
return encoder_outputs, decoder_outputs
|
337 |
+
|
338 |
+
def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
|
339 |
+
"""
|
340 |
+
Computes 2 losses:
|
341 |
+
1. prior loss: loss between mel-spectrogram and encoder outputs.
|
342 |
+
2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
|
346 |
+
x : output of encoder framework. [B x L x d_condition]
|
347 |
+
mel : ground truth mel-spectrogram. [B x L x n_mel]
|
348 |
+
"""
|
349 |
+
|
350 |
+
mu_x = self.encoder(x, x_mask)
|
351 |
+
# prior loss
|
352 |
+
prior_loss = torch.sum(
|
353 |
+
0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
|
354 |
+
)
|
355 |
+
prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
|
356 |
+
# ssim loss
|
357 |
+
ssim_loss = self.ssim_loss(mu_x, mel)
|
358 |
+
ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)
|
359 |
+
|
360 |
+
x_mask = x_mask.transpose(1, 2)
|
361 |
+
mu_x = mu_x.transpose(1, 2)
|
362 |
+
mel = mel.transpose(1, 2)
|
363 |
+
if not self.distill and skip_diff:
|
364 |
+
diff_loss = prior_loss.clone()
|
365 |
+
diff_loss.fill_(0)
|
366 |
+
|
367 |
+
# Cut a small segment of mel-spectrogram in order to increase batch size
|
368 |
+
else:
|
369 |
+
if self.distill:
|
370 |
+
mu_y = mu_x.detach()
|
371 |
+
else:
|
372 |
+
mu_y = mu_x
|
373 |
+
mask_y = x_mask
|
374 |
+
|
375 |
+
diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
|
376 |
+
|
377 |
+
return ssim_loss, prior_loss, diff_loss
|
models/svc/comosvc/comosvc_inference.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from models.svc.base import SVCInference
|
9 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
10 |
+
from models.svc.comosvc.comosvc import ComoSVC
|
11 |
+
|
12 |
+
|
13 |
+
class ComoSVCInference(SVCInference):
|
14 |
+
def __init__(self, args, cfg, infer_type="from_dataset"):
|
15 |
+
SVCInference.__init__(self, args, cfg, infer_type)
|
16 |
+
|
17 |
+
def _build_model(self):
|
18 |
+
# TODO: sort out the config
|
19 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
20 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
21 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
22 |
+
self.acoustic_mapper = ComoSVC(self.cfg)
|
23 |
+
if self.cfg.model.comosvc.distill:
|
24 |
+
self.acoustic_mapper.decoder.init_consistency_training()
|
25 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
26 |
+
return model
|
27 |
+
|
28 |
+
def _inference_each_batch(self, batch_data):
|
29 |
+
device = self.accelerator.device
|
30 |
+
for k, v in batch_data.items():
|
31 |
+
batch_data[k] = v.to(device)
|
32 |
+
|
33 |
+
cond = self.condition_encoder(batch_data)
|
34 |
+
mask = batch_data["mask"]
|
35 |
+
encoder_pred, decoder_pred = self.acoustic_mapper(
|
36 |
+
mask, cond, self.cfg.inference.comosvc.inference_steps
|
37 |
+
)
|
38 |
+
|
39 |
+
return decoder_pred
|
models/svc/comosvc/comosvc_trainer.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
import json5
|
9 |
+
from collections import OrderedDict
|
10 |
+
from tqdm import tqdm
|
11 |
+
import json
|
12 |
+
import shutil
|
13 |
+
|
14 |
+
from models.svc.base import SVCTrainer
|
15 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
16 |
+
from models.svc.comosvc.comosvc import ComoSVC
|
17 |
+
|
18 |
+
|
19 |
+
class ComoSVCTrainer(SVCTrainer):
|
20 |
+
r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
|
21 |
+
implements ``_build_model`` and ``_forward_step`` methods.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, args=None, cfg=None):
|
25 |
+
SVCTrainer.__init__(self, args, cfg)
|
26 |
+
self.distill = cfg.model.comosvc.distill
|
27 |
+
self.skip_diff = True
|
28 |
+
if self.distill: # and args.resume is None:
|
29 |
+
self.teacher_model_path = cfg.model.teacher_model_path
|
30 |
+
self.teacher_state_dict = self._load_teacher_state_dict()
|
31 |
+
self._load_teacher_model(self.teacher_state_dict)
|
32 |
+
self.acoustic_mapper.decoder.init_consistency_training()
|
33 |
+
|
34 |
+
### Following are methods only for comoSVC models ###
|
35 |
+
def _load_teacher_state_dict(self):
|
36 |
+
self.checkpoint_file = self.teacher_model_path
|
37 |
+
print("Load teacher acoustic model from {}".format(self.checkpoint_file))
|
38 |
+
raw_state_dict = torch.load(self.checkpoint_file) # , map_location=self.device)
|
39 |
+
return raw_state_dict
|
40 |
+
|
41 |
+
def _load_teacher_model(self, state_dict):
|
42 |
+
raw_dict = state_dict
|
43 |
+
clean_dict = OrderedDict()
|
44 |
+
for k, v in raw_dict.items():
|
45 |
+
if k.startswith("module."):
|
46 |
+
clean_dict[k[7:]] = v
|
47 |
+
else:
|
48 |
+
clean_dict[k] = v
|
49 |
+
self.model.load_state_dict(clean_dict)
|
50 |
+
|
51 |
+
def _build_model(self):
|
52 |
+
r"""Build the model for training. This function is called in ``__init__`` function."""
|
53 |
+
|
54 |
+
# TODO: sort out the config
|
55 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
56 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
57 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
58 |
+
self.acoustic_mapper = ComoSVC(self.cfg)
|
59 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
60 |
+
return model
|
61 |
+
|
62 |
+
def _forward_step(self, batch):
|
63 |
+
r"""Forward step for training and inference. This function is called
|
64 |
+
in ``_train_step`` & ``_test_step`` function.
|
65 |
+
"""
|
66 |
+
loss = {}
|
67 |
+
mask = batch["mask"]
|
68 |
+
mel_input = batch["mel"]
|
69 |
+
cond = self.condition_encoder(batch)
|
70 |
+
if self.distill:
|
71 |
+
cond = cond.detach()
|
72 |
+
self.skip_diff = True if self.step < self.cfg.train.fast_steps else False
|
73 |
+
ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
|
74 |
+
mask, cond, mel_input, skip_diff=self.skip_diff
|
75 |
+
)
|
76 |
+
if self.distill:
|
77 |
+
loss["distil_loss"] = diff_loss
|
78 |
+
else:
|
79 |
+
loss["ssim_loss_encoder"] = ssim_loss
|
80 |
+
loss["prior_loss_encoder"] = prior_loss
|
81 |
+
loss["diffusion_loss_decoder"] = diff_loss
|
82 |
+
|
83 |
+
return loss
|
84 |
+
|
85 |
+
def _train_epoch(self):
|
86 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
87 |
+
one epoch. See ``train_loop`` for usage.
|
88 |
+
"""
|
89 |
+
self.model.train()
|
90 |
+
epoch_sum_loss: float = 0.0
|
91 |
+
epoch_step: int = 0
|
92 |
+
for batch in tqdm(
|
93 |
+
self.train_dataloader,
|
94 |
+
desc=f"Training Epoch {self.epoch}",
|
95 |
+
unit="batch",
|
96 |
+
colour="GREEN",
|
97 |
+
leave=False,
|
98 |
+
dynamic_ncols=True,
|
99 |
+
smoothing=0.04,
|
100 |
+
disable=not self.accelerator.is_main_process,
|
101 |
+
):
|
102 |
+
# Do training step and BP
|
103 |
+
with self.accelerator.accumulate(self.model):
|
104 |
+
loss = self._train_step(batch)
|
105 |
+
total_loss = 0
|
106 |
+
for k, v in loss.items():
|
107 |
+
total_loss += v
|
108 |
+
self.accelerator.backward(total_loss)
|
109 |
+
enc_grad_norm = torch.nn.utils.clip_grad_norm_(
|
110 |
+
self.acoustic_mapper.encoder.parameters(), max_norm=1
|
111 |
+
)
|
112 |
+
dec_grad_norm = torch.nn.utils.clip_grad_norm_(
|
113 |
+
self.acoustic_mapper.decoder.parameters(), max_norm=1
|
114 |
+
)
|
115 |
+
self.optimizer.step()
|
116 |
+
self.optimizer.zero_grad()
|
117 |
+
self.batch_count += 1
|
118 |
+
|
119 |
+
# Update info for each step
|
120 |
+
# TODO: step means BP counts or batch counts?
|
121 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
122 |
+
epoch_sum_loss += total_loss
|
123 |
+
log_info = {}
|
124 |
+
for k, v in loss.items():
|
125 |
+
key = "Step/Train Loss/{}".format(k)
|
126 |
+
log_info[key] = v
|
127 |
+
log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"]
|
128 |
+
self.accelerator.log(
|
129 |
+
log_info,
|
130 |
+
step=self.step,
|
131 |
+
)
|
132 |
+
self.step += 1
|
133 |
+
epoch_step += 1
|
134 |
+
|
135 |
+
self.accelerator.wait_for_everyone()
|
136 |
+
return (
|
137 |
+
epoch_sum_loss
|
138 |
+
/ len(self.train_dataloader)
|
139 |
+
* self.cfg.train.gradient_accumulation_step,
|
140 |
+
loss,
|
141 |
+
)
|
142 |
+
|
143 |
+
def train_loop(self):
|
144 |
+
r"""Training loop. The public entry of training process."""
|
145 |
+
# Wait everyone to prepare before we move on
|
146 |
+
self.accelerator.wait_for_everyone()
|
147 |
+
# dump config file
|
148 |
+
if self.accelerator.is_main_process:
|
149 |
+
self.__dump_cfg(self.config_save_path)
|
150 |
+
self.model.train()
|
151 |
+
self.optimizer.zero_grad()
|
152 |
+
# Wait to ensure good to go
|
153 |
+
self.accelerator.wait_for_everyone()
|
154 |
+
while self.epoch < self.max_epoch:
|
155 |
+
self.logger.info("\n")
|
156 |
+
self.logger.info("-" * 32)
|
157 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
158 |
+
|
159 |
+
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
|
160 |
+
### It's inconvenient for the model with multiple losses
|
161 |
+
# Do training & validating epoch
|
162 |
+
train_loss, loss = self._train_epoch()
|
163 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
|
164 |
+
for k, v in loss.items():
|
165 |
+
self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v))
|
166 |
+
valid_loss = self._valid_epoch()
|
167 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
|
168 |
+
self.accelerator.log(
|
169 |
+
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
|
170 |
+
step=self.epoch,
|
171 |
+
)
|
172 |
+
|
173 |
+
self.accelerator.wait_for_everyone()
|
174 |
+
# TODO: what is scheduler?
|
175 |
+
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
|
176 |
+
|
177 |
+
# Check if hit save_checkpoint_stride and run_eval
|
178 |
+
run_eval = False
|
179 |
+
if self.accelerator.is_main_process:
|
180 |
+
save_checkpoint = False
|
181 |
+
hit_dix = []
|
182 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
183 |
+
if self.epoch % num == 0:
|
184 |
+
save_checkpoint = True
|
185 |
+
hit_dix.append(i)
|
186 |
+
run_eval |= self.run_eval[i]
|
187 |
+
|
188 |
+
self.accelerator.wait_for_everyone()
|
189 |
+
if (
|
190 |
+
self.accelerator.is_main_process
|
191 |
+
and save_checkpoint
|
192 |
+
and (self.distill or not self.skip_diff)
|
193 |
+
):
|
194 |
+
path = os.path.join(
|
195 |
+
self.checkpoint_dir,
|
196 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
197 |
+
self.epoch, self.step, train_loss
|
198 |
+
),
|
199 |
+
)
|
200 |
+
self.accelerator.save_state(path)
|
201 |
+
json.dump(
|
202 |
+
self.checkpoints_path,
|
203 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
204 |
+
ensure_ascii=False,
|
205 |
+
indent=4,
|
206 |
+
)
|
207 |
+
|
208 |
+
# Remove old checkpoints
|
209 |
+
to_remove = []
|
210 |
+
for idx in hit_dix:
|
211 |
+
self.checkpoints_path[idx].append(path)
|
212 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
213 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
214 |
+
|
215 |
+
# Search conflicts
|
216 |
+
total = set()
|
217 |
+
for i in self.checkpoints_path:
|
218 |
+
total |= set(i)
|
219 |
+
do_remove = set()
|
220 |
+
for idx, path in to_remove[::-1]:
|
221 |
+
if path in total:
|
222 |
+
self.checkpoints_path[idx].insert(0, path)
|
223 |
+
else:
|
224 |
+
do_remove.add(path)
|
225 |
+
|
226 |
+
# Remove old checkpoints
|
227 |
+
for path in do_remove:
|
228 |
+
shutil.rmtree(path, ignore_errors=True)
|
229 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
230 |
+
|
231 |
+
self.accelerator.wait_for_everyone()
|
232 |
+
if run_eval:
|
233 |
+
# TODO: run evaluation
|
234 |
+
pass
|
235 |
+
|
236 |
+
# Update info for each epoch
|
237 |
+
self.epoch += 1
|
238 |
+
|
239 |
+
# Finish training and save final checkpoint
|
240 |
+
self.accelerator.wait_for_everyone()
|
241 |
+
if self.accelerator.is_main_process:
|
242 |
+
self.accelerator.save_state(
|
243 |
+
os.path.join(
|
244 |
+
self.checkpoint_dir,
|
245 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
246 |
+
self.epoch, self.step, valid_loss
|
247 |
+
),
|
248 |
+
)
|
249 |
+
)
|
250 |
+
self.accelerator.end_training()
|
251 |
+
|
252 |
+
@torch.inference_mode()
|
253 |
+
def _valid_epoch(self):
|
254 |
+
r"""Testing epoch. Should return average loss of a batch (sample) over
|
255 |
+
one epoch. See ``train_loop`` for usage.
|
256 |
+
"""
|
257 |
+
self.model.eval()
|
258 |
+
epoch_sum_loss = 0.0
|
259 |
+
for batch in tqdm(
|
260 |
+
self.valid_dataloader,
|
261 |
+
desc=f"Validating Epoch {self.epoch}",
|
262 |
+
unit="batch",
|
263 |
+
colour="GREEN",
|
264 |
+
leave=False,
|
265 |
+
dynamic_ncols=True,
|
266 |
+
smoothing=0.04,
|
267 |
+
disable=not self.accelerator.is_main_process,
|
268 |
+
):
|
269 |
+
batch_loss = self._valid_step(batch)
|
270 |
+
for k, v in batch_loss.items():
|
271 |
+
epoch_sum_loss += v
|
272 |
+
|
273 |
+
self.accelerator.wait_for_everyone()
|
274 |
+
return epoch_sum_loss / len(self.valid_dataloader)
|
275 |
+
|
276 |
+
@staticmethod
|
277 |
+
def __count_parameters(model):
|
278 |
+
model_param = 0.0
|
279 |
+
if isinstance(model, dict):
|
280 |
+
for key, value in model.items():
|
281 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
282 |
+
else:
|
283 |
+
model_param = sum(p.numel() for p in model.parameters())
|
284 |
+
return model_param
|
285 |
+
|
286 |
+
def __dump_cfg(self, path):
|
287 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
288 |
+
json5.dump(
|
289 |
+
self.cfg,
|
290 |
+
open(path, "w"),
|
291 |
+
indent=4,
|
292 |
+
sort_keys=True,
|
293 |
+
ensure_ascii=False,
|
294 |
+
quote_keys=True,
|
295 |
+
)
|
models/svc/comosvc/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def slice_segments(x, ids_str, segment_size=200):
|
10 |
+
ret = torch.zeros_like(x[:, :, :segment_size])
|
11 |
+
for i in range(x.size(0)):
|
12 |
+
idx_str = ids_str[i]
|
13 |
+
idx_end = idx_str + segment_size
|
14 |
+
ret[i] = x[i, :, idx_str:idx_end]
|
15 |
+
return ret
|
16 |
+
|
17 |
+
|
18 |
+
def rand_ids_segments(lengths, segment_size=200):
|
19 |
+
b = lengths.shape[0]
|
20 |
+
ids_str_max = lengths - segment_size
|
21 |
+
ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(
|
22 |
+
dtype=torch.long
|
23 |
+
)
|
24 |
+
return ids_str
|
25 |
+
|
26 |
+
|
27 |
+
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
28 |
+
while True:
|
29 |
+
if length % (2**num_downsamplings_in_unet) == 0:
|
30 |
+
return length
|
31 |
+
length += 1
|
models/svc/diffusion/__init__.py
ADDED
File without changes
|
models/svc/diffusion/diffusion_inference.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
|
8 |
+
|
9 |
+
from models.svc.base import SVCInference
|
10 |
+
from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
|
11 |
+
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
|
12 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
13 |
+
|
14 |
+
|
15 |
+
class DiffusionInference(SVCInference):
|
16 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
17 |
+
SVCInference.__init__(self, args, cfg, infer_type)
|
18 |
+
|
19 |
+
settings = {
|
20 |
+
**cfg.model.diffusion.scheduler_settings,
|
21 |
+
**cfg.inference.diffusion.scheduler_settings,
|
22 |
+
}
|
23 |
+
settings.pop("num_inference_timesteps")
|
24 |
+
|
25 |
+
if cfg.inference.diffusion.scheduler.lower() == "ddpm":
|
26 |
+
self.scheduler = DDPMScheduler(**settings)
|
27 |
+
self.logger.info("Using DDPM scheduler.")
|
28 |
+
elif cfg.inference.diffusion.scheduler.lower() == "ddim":
|
29 |
+
self.scheduler = DDIMScheduler(**settings)
|
30 |
+
self.logger.info("Using DDIM scheduler.")
|
31 |
+
elif cfg.inference.diffusion.scheduler.lower() == "pndm":
|
32 |
+
self.scheduler = PNDMScheduler(**settings)
|
33 |
+
self.logger.info("Using PNDM scheduler.")
|
34 |
+
else:
|
35 |
+
raise NotImplementedError(
|
36 |
+
"Unsupported scheduler type: {}".format(
|
37 |
+
cfg.inference.diffusion.scheduler.lower()
|
38 |
+
)
|
39 |
+
)
|
40 |
+
|
41 |
+
self.pipeline = DiffusionInferencePipeline(
|
42 |
+
self.model[1],
|
43 |
+
self.scheduler,
|
44 |
+
cfg.inference.diffusion.scheduler_settings.num_inference_timesteps,
|
45 |
+
)
|
46 |
+
|
47 |
+
def _build_model(self):
|
48 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
49 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
50 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
51 |
+
self.acoustic_mapper = DiffusionWrapper(self.cfg)
|
52 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
53 |
+
return model
|
54 |
+
|
55 |
+
def _inference_each_batch(self, batch_data):
|
56 |
+
device = self.accelerator.device
|
57 |
+
for k, v in batch_data.items():
|
58 |
+
batch_data[k] = v.to(device)
|
59 |
+
|
60 |
+
conditioner = self.model[0](batch_data)
|
61 |
+
noise = torch.randn_like(batch_data["mel"], device=device)
|
62 |
+
y_pred = self.pipeline(noise, conditioner)
|
63 |
+
return y_pred
|
models/svc/diffusion/diffusion_inference_pipeline.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import DiffusionPipeline
|
8 |
+
|
9 |
+
|
10 |
+
class DiffusionInferencePipeline(DiffusionPipeline):
|
11 |
+
def __init__(self, network, scheduler, num_inference_timesteps=1000):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
self.register_modules(network=network, scheduler=scheduler)
|
15 |
+
self.num_inference_timesteps = num_inference_timesteps
|
16 |
+
|
17 |
+
@torch.inference_mode()
|
18 |
+
def __call__(
|
19 |
+
self,
|
20 |
+
initial_noise: torch.Tensor,
|
21 |
+
conditioner: torch.Tensor = None,
|
22 |
+
):
|
23 |
+
r"""
|
24 |
+
Args:
|
25 |
+
initial_noise: The initial noise to be denoised.
|
26 |
+
conditioner:The conditioner.
|
27 |
+
n_inference_steps: The number of denoising steps. More denoising steps
|
28 |
+
usually lead to a higher quality at the expense of slower inference.
|
29 |
+
"""
|
30 |
+
|
31 |
+
mel = initial_noise
|
32 |
+
batch_size = mel.size(0)
|
33 |
+
self.scheduler.set_timesteps(self.num_inference_timesteps)
|
34 |
+
|
35 |
+
for t in self.progress_bar(self.scheduler.timesteps):
|
36 |
+
timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)
|
37 |
+
|
38 |
+
# 1. predict noise model_output
|
39 |
+
model_output = self.network(mel, timestep, conditioner)
|
40 |
+
|
41 |
+
# 2. denoise, compute previous step: x_t -> x_t-1
|
42 |
+
mel = self.scheduler.step(model_output, t, mel).prev_sample
|
43 |
+
|
44 |
+
# 3. clamp
|
45 |
+
mel = mel.clamp(-1.0, 1.0)
|
46 |
+
|
47 |
+
return mel
|
models/svc/diffusion/diffusion_trainer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import DDPMScheduler
|
8 |
+
|
9 |
+
from models.svc.base import SVCTrainer
|
10 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
11 |
+
from .diffusion_wrapper import DiffusionWrapper
|
12 |
+
|
13 |
+
|
14 |
+
class DiffusionTrainer(SVCTrainer):
|
15 |
+
r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
|
16 |
+
implements ``_build_model`` and ``_forward_step`` methods.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, args=None, cfg=None):
|
20 |
+
SVCTrainer.__init__(self, args, cfg)
|
21 |
+
|
22 |
+
# Only for SVC tasks using diffusion
|
23 |
+
self.noise_scheduler = DDPMScheduler(
|
24 |
+
**self.cfg.model.diffusion.scheduler_settings,
|
25 |
+
)
|
26 |
+
self.diffusion_timesteps = (
|
27 |
+
self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
|
28 |
+
)
|
29 |
+
|
30 |
+
### Following are methods only for diffusion models ###
|
31 |
+
def _build_model(self):
|
32 |
+
r"""Build the model for training. This function is called in ``__init__`` function."""
|
33 |
+
|
34 |
+
# TODO: sort out the config
|
35 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
36 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
37 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
38 |
+
self.acoustic_mapper = DiffusionWrapper(self.cfg)
|
39 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
40 |
+
|
41 |
+
num_of_params_encoder = self.count_parameters(self.condition_encoder)
|
42 |
+
num_of_params_am = self.count_parameters(self.acoustic_mapper)
|
43 |
+
num_of_params = num_of_params_encoder + num_of_params_am
|
44 |
+
log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
|
45 |
+
num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
|
46 |
+
)
|
47 |
+
self.logger.info(log)
|
48 |
+
|
49 |
+
return model
|
50 |
+
|
51 |
+
def count_parameters(self, model):
|
52 |
+
model_param = 0.0
|
53 |
+
if isinstance(model, dict):
|
54 |
+
for key, value in model.items():
|
55 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
56 |
+
else:
|
57 |
+
model_param = sum(p.numel() for p in model.parameters())
|
58 |
+
return model_param
|
59 |
+
|
60 |
+
def _forward_step(self, batch):
|
61 |
+
r"""Forward step for training and inference. This function is called
|
62 |
+
in ``_train_step`` & ``_test_step`` function.
|
63 |
+
"""
|
64 |
+
|
65 |
+
device = self.accelerator.device
|
66 |
+
|
67 |
+
mel_input = batch["mel"]
|
68 |
+
noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
|
69 |
+
batch_size = mel_input.size(0)
|
70 |
+
timesteps = torch.randint(
|
71 |
+
0,
|
72 |
+
self.diffusion_timesteps,
|
73 |
+
(batch_size,),
|
74 |
+
device=device,
|
75 |
+
dtype=torch.long,
|
76 |
+
)
|
77 |
+
|
78 |
+
noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
|
79 |
+
conditioner = self.condition_encoder(batch)
|
80 |
+
|
81 |
+
y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)
|
82 |
+
|
83 |
+
# TODO: Predict noise or gt should be configurable
|
84 |
+
loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
|
85 |
+
self._check_nan(loss, y_pred, noise)
|
86 |
+
|
87 |
+
# FIXME: Clarify that we should not divide it with batch size here
|
88 |
+
return loss
|
models/svc/diffusion/diffusion_wrapper.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from modules.diffusion import BiDilConv
|
9 |
+
from modules.encoder.position_encoder import PositionEncoder
|
10 |
+
|
11 |
+
|
12 |
+
class DiffusionWrapper(nn.Module):
|
13 |
+
def __init__(self, cfg):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
self.cfg = cfg
|
17 |
+
self.diff_cfg = cfg.model.diffusion
|
18 |
+
|
19 |
+
self.diff_encoder = PositionEncoder(
|
20 |
+
d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
|
21 |
+
d_out=self.diff_cfg.bidilconv.base_channel,
|
22 |
+
d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
|
23 |
+
activation_function=self.diff_cfg.step_encoder.activation,
|
24 |
+
n_layer=self.diff_cfg.step_encoder.num_layer,
|
25 |
+
max_period=self.diff_cfg.step_encoder.max_period,
|
26 |
+
)
|
27 |
+
|
28 |
+
# FIXME: Only support BiDilConv now for debug
|
29 |
+
if self.diff_cfg.model_type.lower() == "bidilconv":
|
30 |
+
self.neural_network = BiDilConv(
|
31 |
+
input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
raise ValueError(
|
35 |
+
f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, x, t, c):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
x: [N, T, mel_band] of mel spectrogram
|
42 |
+
t: Diffusion time step with shape of [N]
|
43 |
+
c: [N, T, conditioner_size] of conditioner
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
[N, T, mel_band] of mel spectrogram
|
47 |
+
"""
|
48 |
+
|
49 |
+
assert (
|
50 |
+
x.size()[:-1] == c.size()[:-1]
|
51 |
+
), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
|
52 |
+
assert x.size(0) == t.size(
|
53 |
+
0
|
54 |
+
), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
|
55 |
+
assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
|
56 |
+
|
57 |
+
N, T, mel_band = x.size()
|
58 |
+
|
59 |
+
x = x.transpose(1, 2).contiguous() # [N, mel_band, T]
|
60 |
+
c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T]
|
61 |
+
t = self.diff_encoder(t).contiguous() # [N, base_channel]
|
62 |
+
|
63 |
+
h = self.neural_network(x, t, c)
|
64 |
+
h = h.transpose(1, 2).contiguous() # [N, T, mel_band]
|
65 |
+
|
66 |
+
assert h.size() == (
|
67 |
+
N,
|
68 |
+
T,
|
69 |
+
mel_band,
|
70 |
+
), "h mismatch with input x, got \n h: {} \n x: {}".format(
|
71 |
+
h.size(), (N, T, mel_band)
|
72 |
+
)
|
73 |
+
return h
|
models/svc/transformer/__init__.py
ADDED
File without changes
|
models/svc/transformer/conformer.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import torch.nn as nn
|
10 |
+
from utils.util import convert_pad_shape
|
11 |
+
|
12 |
+
|
13 |
+
class BaseModule(torch.nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super(BaseModule, self).__init__()
|
16 |
+
|
17 |
+
@property
|
18 |
+
def nparams(self):
|
19 |
+
"""
|
20 |
+
Returns number of trainable parameters of the module.
|
21 |
+
"""
|
22 |
+
num_params = 0
|
23 |
+
for name, param in self.named_parameters():
|
24 |
+
if param.requires_grad:
|
25 |
+
num_params += np.prod(param.detach().cpu().numpy().shape)
|
26 |
+
return num_params
|
27 |
+
|
28 |
+
def relocate_input(self, x: list):
|
29 |
+
"""
|
30 |
+
Relocates provided tensors to the same device set for the module.
|
31 |
+
"""
|
32 |
+
device = next(self.parameters()).device
|
33 |
+
for i in range(len(x)):
|
34 |
+
if isinstance(x[i], torch.Tensor) and x[i].device != device:
|
35 |
+
x[i] = x[i].to(device)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
class LayerNorm(BaseModule):
|
40 |
+
def __init__(self, channels, eps=1e-4):
|
41 |
+
super(LayerNorm, self).__init__()
|
42 |
+
self.channels = channels
|
43 |
+
self.eps = eps
|
44 |
+
|
45 |
+
self.gamma = torch.nn.Parameter(torch.ones(channels))
|
46 |
+
self.beta = torch.nn.Parameter(torch.zeros(channels))
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
n_dims = len(x.shape)
|
50 |
+
mean = torch.mean(x, 1, keepdim=True)
|
51 |
+
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
52 |
+
|
53 |
+
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
54 |
+
|
55 |
+
shape = [1, -1] + [1] * (n_dims - 2)
|
56 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class ConvReluNorm(BaseModule):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
in_channels,
|
64 |
+
hidden_channels,
|
65 |
+
out_channels,
|
66 |
+
kernel_size,
|
67 |
+
n_layers,
|
68 |
+
p_dropout,
|
69 |
+
eps=1e-5,
|
70 |
+
):
|
71 |
+
super(ConvReluNorm, self).__init__()
|
72 |
+
self.in_channels = in_channels
|
73 |
+
self.hidden_channels = hidden_channels
|
74 |
+
self.out_channels = out_channels
|
75 |
+
self.kernel_size = kernel_size
|
76 |
+
self.n_layers = n_layers
|
77 |
+
self.p_dropout = p_dropout
|
78 |
+
self.eps = eps
|
79 |
+
|
80 |
+
self.conv_layers = torch.nn.ModuleList()
|
81 |
+
self.conv_layers.append(
|
82 |
+
torch.nn.Conv1d(
|
83 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
84 |
+
)
|
85 |
+
)
|
86 |
+
self.relu_drop = torch.nn.Sequential(
|
87 |
+
torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
|
88 |
+
)
|
89 |
+
for _ in range(n_layers - 1):
|
90 |
+
self.conv_layers.append(
|
91 |
+
torch.nn.Conv1d(
|
92 |
+
hidden_channels,
|
93 |
+
hidden_channels,
|
94 |
+
kernel_size,
|
95 |
+
padding=kernel_size // 2,
|
96 |
+
)
|
97 |
+
)
|
98 |
+
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
|
99 |
+
self.proj.weight.data.zero_()
|
100 |
+
self.proj.bias.data.zero_()
|
101 |
+
|
102 |
+
def forward(self, x, x_mask):
|
103 |
+
for i in range(self.n_layers):
|
104 |
+
x = self.conv_layers[i](x * x_mask)
|
105 |
+
x = self.instance_norm(x, x_mask)
|
106 |
+
x = self.relu_drop(x)
|
107 |
+
x = self.proj(x)
|
108 |
+
return x * x_mask
|
109 |
+
|
110 |
+
def instance_norm(self, x, mask, return_mean_std=False):
|
111 |
+
mean, std = self.calc_mean_std(x, mask)
|
112 |
+
x = (x - mean) / std
|
113 |
+
if return_mean_std:
|
114 |
+
return x, mean, std
|
115 |
+
else:
|
116 |
+
return x
|
117 |
+
|
118 |
+
def calc_mean_std(self, x, mask=None):
|
119 |
+
x = x * mask
|
120 |
+
B, C = x.shape[:2]
|
121 |
+
mn = x.view(B, C, -1).mean(-1)
|
122 |
+
sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
|
123 |
+
mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
|
124 |
+
sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
|
125 |
+
return mn, sd
|
126 |
+
|
127 |
+
|
128 |
+
class MultiHeadAttention(BaseModule):
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
channels,
|
132 |
+
out_channels,
|
133 |
+
n_heads,
|
134 |
+
window_size=None,
|
135 |
+
heads_share=True,
|
136 |
+
p_dropout=0.0,
|
137 |
+
proximal_bias=False,
|
138 |
+
proximal_init=False,
|
139 |
+
):
|
140 |
+
super(MultiHeadAttention, self).__init__()
|
141 |
+
assert channels % n_heads == 0
|
142 |
+
|
143 |
+
self.channels = channels
|
144 |
+
self.out_channels = out_channels
|
145 |
+
self.n_heads = n_heads
|
146 |
+
self.window_size = window_size
|
147 |
+
self.heads_share = heads_share
|
148 |
+
self.proximal_bias = proximal_bias
|
149 |
+
self.p_dropout = p_dropout
|
150 |
+
self.attn = None
|
151 |
+
|
152 |
+
self.k_channels = channels // n_heads
|
153 |
+
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
154 |
+
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
155 |
+
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
156 |
+
if window_size is not None:
|
157 |
+
n_heads_rel = 1 if heads_share else n_heads
|
158 |
+
rel_stddev = self.k_channels**-0.5
|
159 |
+
self.emb_rel_k = torch.nn.Parameter(
|
160 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
161 |
+
* rel_stddev
|
162 |
+
)
|
163 |
+
self.emb_rel_v = torch.nn.Parameter(
|
164 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
165 |
+
* rel_stddev
|
166 |
+
)
|
167 |
+
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
168 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
169 |
+
|
170 |
+
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
171 |
+
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
172 |
+
if proximal_init:
|
173 |
+
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
174 |
+
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
175 |
+
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
176 |
+
|
177 |
+
def forward(self, x, c, attn_mask=None):
|
178 |
+
q = self.conv_q(x)
|
179 |
+
k = self.conv_k(c)
|
180 |
+
v = self.conv_v(c)
|
181 |
+
|
182 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
183 |
+
|
184 |
+
x = self.conv_o(x)
|
185 |
+
return x
|
186 |
+
|
187 |
+
def attention(self, query, key, value, mask=None):
|
188 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
189 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
190 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
191 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
192 |
+
|
193 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
194 |
+
if self.window_size is not None:
|
195 |
+
assert (
|
196 |
+
t_s == t_t
|
197 |
+
), "Relative attention is only available for self-attention."
|
198 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
199 |
+
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
|
200 |
+
rel_logits = self._relative_position_to_absolute_position(rel_logits)
|
201 |
+
scores_local = rel_logits / math.sqrt(self.k_channels)
|
202 |
+
scores = scores + scores_local
|
203 |
+
if self.proximal_bias:
|
204 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
205 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
206 |
+
device=scores.device, dtype=scores.dtype
|
207 |
+
)
|
208 |
+
if mask is not None:
|
209 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
210 |
+
p_attn = torch.nn.functional.softmax(scores, dim=-1)
|
211 |
+
p_attn = self.drop(p_attn)
|
212 |
+
output = torch.matmul(p_attn, value)
|
213 |
+
if self.window_size is not None:
|
214 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
215 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
216 |
+
self.emb_rel_v, t_s
|
217 |
+
)
|
218 |
+
output = output + self._matmul_with_relative_values(
|
219 |
+
relative_weights, value_relative_embeddings
|
220 |
+
)
|
221 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
|
222 |
+
return output, p_attn
|
223 |
+
|
224 |
+
def _matmul_with_relative_values(self, x, y):
|
225 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
226 |
+
return ret
|
227 |
+
|
228 |
+
def _matmul_with_relative_keys(self, x, y):
|
229 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
230 |
+
return ret
|
231 |
+
|
232 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
233 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
234 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
235 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
236 |
+
if pad_length > 0:
|
237 |
+
padded_relative_embeddings = torch.nn.functional.pad(
|
238 |
+
relative_embeddings,
|
239 |
+
convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
padded_relative_embeddings = relative_embeddings
|
243 |
+
used_relative_embeddings = padded_relative_embeddings[
|
244 |
+
:, slice_start_position:slice_end_position
|
245 |
+
]
|
246 |
+
return used_relative_embeddings
|
247 |
+
|
248 |
+
def _relative_position_to_absolute_position(self, x):
|
249 |
+
batch, heads, length, _ = x.size()
|
250 |
+
x = torch.nn.functional.pad(
|
251 |
+
x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
|
252 |
+
)
|
253 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
254 |
+
x_flat = torch.nn.functional.pad(
|
255 |
+
x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
256 |
+
)
|
257 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
258 |
+
:, :, :length, length - 1 :
|
259 |
+
]
|
260 |
+
return x_final
|
261 |
+
|
262 |
+
def _absolute_position_to_relative_position(self, x):
|
263 |
+
batch, heads, length, _ = x.size()
|
264 |
+
x = torch.nn.functional.pad(
|
265 |
+
x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
266 |
+
)
|
267 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
268 |
+
x_flat = torch.nn.functional.pad(
|
269 |
+
x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])
|
270 |
+
)
|
271 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
272 |
+
return x_final
|
273 |
+
|
274 |
+
def _attention_bias_proximal(self, length):
|
275 |
+
r = torch.arange(length, dtype=torch.float32)
|
276 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
277 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
278 |
+
|
279 |
+
|
280 |
+
class FFN(BaseModule):
|
281 |
+
def __init__(
|
282 |
+
self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
|
283 |
+
):
|
284 |
+
super(FFN, self).__init__()
|
285 |
+
self.in_channels = in_channels
|
286 |
+
self.out_channels = out_channels
|
287 |
+
self.filter_channels = filter_channels
|
288 |
+
self.kernel_size = kernel_size
|
289 |
+
self.p_dropout = p_dropout
|
290 |
+
|
291 |
+
self.conv_1 = torch.nn.Conv1d(
|
292 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
293 |
+
)
|
294 |
+
self.conv_2 = torch.nn.Conv1d(
|
295 |
+
filter_channels, out_channels, kernel_size, padding=kernel_size // 2
|
296 |
+
)
|
297 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
298 |
+
|
299 |
+
def forward(self, x, x_mask):
|
300 |
+
x = self.conv_1(x * x_mask)
|
301 |
+
x = torch.relu(x)
|
302 |
+
x = self.drop(x)
|
303 |
+
x = self.conv_2(x * x_mask)
|
304 |
+
return x * x_mask
|
305 |
+
|
306 |
+
|
307 |
+
class Encoder(BaseModule):
|
308 |
+
def __init__(
|
309 |
+
self,
|
310 |
+
hidden_channels,
|
311 |
+
filter_channels,
|
312 |
+
n_heads=2,
|
313 |
+
n_layers=6,
|
314 |
+
kernel_size=3,
|
315 |
+
p_dropout=0.1,
|
316 |
+
window_size=4,
|
317 |
+
**kwargs
|
318 |
+
):
|
319 |
+
super(Encoder, self).__init__()
|
320 |
+
self.hidden_channels = hidden_channels
|
321 |
+
self.filter_channels = filter_channels
|
322 |
+
self.n_heads = n_heads
|
323 |
+
self.n_layers = n_layers
|
324 |
+
self.kernel_size = kernel_size
|
325 |
+
self.p_dropout = p_dropout
|
326 |
+
self.window_size = window_size
|
327 |
+
|
328 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
329 |
+
self.attn_layers = torch.nn.ModuleList()
|
330 |
+
self.norm_layers_1 = torch.nn.ModuleList()
|
331 |
+
self.ffn_layers = torch.nn.ModuleList()
|
332 |
+
self.norm_layers_2 = torch.nn.ModuleList()
|
333 |
+
for _ in range(self.n_layers):
|
334 |
+
self.attn_layers.append(
|
335 |
+
MultiHeadAttention(
|
336 |
+
hidden_channels,
|
337 |
+
hidden_channels,
|
338 |
+
n_heads,
|
339 |
+
window_size=window_size,
|
340 |
+
p_dropout=p_dropout,
|
341 |
+
)
|
342 |
+
)
|
343 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
344 |
+
self.ffn_layers.append(
|
345 |
+
FFN(
|
346 |
+
hidden_channels,
|
347 |
+
hidden_channels,
|
348 |
+
filter_channels,
|
349 |
+
kernel_size,
|
350 |
+
p_dropout=p_dropout,
|
351 |
+
)
|
352 |
+
)
|
353 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
354 |
+
|
355 |
+
def forward(self, x, x_mask):
|
356 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
357 |
+
for i in range(self.n_layers):
|
358 |
+
x = x * x_mask
|
359 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
360 |
+
y = self.drop(y)
|
361 |
+
x = self.norm_layers_1[i](x + y)
|
362 |
+
y = self.ffn_layers[i](x, x_mask)
|
363 |
+
y = self.drop(y)
|
364 |
+
x = self.norm_layers_2[i](x + y)
|
365 |
+
x = x * x_mask
|
366 |
+
return x
|
367 |
+
|
368 |
+
|
369 |
+
class Conformer(BaseModule):
|
370 |
+
def __init__(self, cfg):
|
371 |
+
super().__init__()
|
372 |
+
self.cfg = cfg
|
373 |
+
self.n_heads = self.cfg.n_heads
|
374 |
+
self.n_layers = self.cfg.n_layers
|
375 |
+
self.hidden_channels = self.cfg.input_dim
|
376 |
+
self.filter_channels = self.cfg.filter_channels
|
377 |
+
self.output_dim = self.cfg.output_dim
|
378 |
+
self.dropout = self.cfg.dropout
|
379 |
+
|
380 |
+
self.conformer_encoder = Encoder(
|
381 |
+
self.hidden_channels,
|
382 |
+
self.filter_channels,
|
383 |
+
n_heads=self.n_heads,
|
384 |
+
n_layers=self.n_layers,
|
385 |
+
kernel_size=3,
|
386 |
+
p_dropout=self.dropout,
|
387 |
+
window_size=4,
|
388 |
+
)
|
389 |
+
self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1)
|
390 |
+
|
391 |
+
def forward(self, x, x_mask):
|
392 |
+
"""
|
393 |
+
Args:
|
394 |
+
x: (N, seq_len, input_dim)
|
395 |
+
Returns:
|
396 |
+
output: (N, seq_len, output_dim)
|
397 |
+
"""
|
398 |
+
# (N, seq_len, d_model)
|
399 |
+
x = x.transpose(1, 2)
|
400 |
+
x_mask = x_mask.transpose(1, 2)
|
401 |
+
output = self.conformer_encoder(x, x_mask)
|
402 |
+
# (N, seq_len, output_dim)
|
403 |
+
output = self.projection(output)
|
404 |
+
output = output.transpose(1, 2)
|
405 |
+
return output
|
models/svc/transformer/transformer.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
10 |
+
|
11 |
+
|
12 |
+
class Transformer(nn.Module):
|
13 |
+
def __init__(self, cfg):
|
14 |
+
super().__init__()
|
15 |
+
self.cfg = cfg
|
16 |
+
|
17 |
+
dropout = self.cfg.dropout
|
18 |
+
nhead = self.cfg.n_heads
|
19 |
+
nlayers = self.cfg.n_layers
|
20 |
+
input_dim = self.cfg.input_dim
|
21 |
+
output_dim = self.cfg.output_dim
|
22 |
+
|
23 |
+
d_model = input_dim
|
24 |
+
self.pos_encoder = PositionalEncoding(d_model, dropout)
|
25 |
+
encoder_layers = TransformerEncoderLayer(
|
26 |
+
d_model, nhead, dropout=dropout, batch_first=True
|
27 |
+
)
|
28 |
+
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
|
29 |
+
|
30 |
+
self.output_mlp = nn.Linear(d_model, output_dim)
|
31 |
+
|
32 |
+
def forward(self, x, mask=None):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
x: (N, seq_len, input_dim)
|
36 |
+
Returns:
|
37 |
+
output: (N, seq_len, output_dim)
|
38 |
+
"""
|
39 |
+
# (N, seq_len, d_model)
|
40 |
+
src = self.pos_encoder(x)
|
41 |
+
# model_stats["pos_embedding"] = x
|
42 |
+
# (N, seq_len, d_model)
|
43 |
+
output = self.transformer_encoder(src)
|
44 |
+
# (N, seq_len, output_dim)
|
45 |
+
output = self.output_mlp(output)
|
46 |
+
return output
|
47 |
+
|
48 |
+
|
49 |
+
class PositionalEncoding(nn.Module):
|
50 |
+
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
51 |
+
super().__init__()
|
52 |
+
self.dropout = nn.Dropout(p=dropout)
|
53 |
+
|
54 |
+
position = torch.arange(max_len).unsqueeze(1)
|
55 |
+
div_term = torch.exp(
|
56 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
57 |
+
)
|
58 |
+
|
59 |
+
# Assume that x is (seq_len, N, d)
|
60 |
+
# pe = torch.zeros(max_len, 1, d_model)
|
61 |
+
# pe[:, 0, 0::2] = torch.sin(position * div_term)
|
62 |
+
# pe[:, 0, 1::2] = torch.cos(position * div_term)
|
63 |
+
|
64 |
+
# Assume that x in (N, seq_len, d)
|
65 |
+
pe = torch.zeros(1, max_len, d_model)
|
66 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
67 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
68 |
+
|
69 |
+
self.register_buffer("pe", pe)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
"""
|
73 |
+
Args:
|
74 |
+
x: Tensor, shape [N, seq_len, d]
|
75 |
+
"""
|
76 |
+
# Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
|
77 |
+
# x = x + self.pe[: x.size(0)]
|
78 |
+
|
79 |
+
# Now: self.pe is (1, max_len, d)
|
80 |
+
x = x + self.pe[:, : x.size(1), :]
|
81 |
+
|
82 |
+
return self.dropout(x)
|
models/svc/transformer/transformer_inference.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import torch.nn as nn
|
12 |
+
from collections import OrderedDict
|
13 |
+
|
14 |
+
from models.svc.base import SVCInference
|
15 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
16 |
+
from models.svc.transformer.transformer import Transformer
|
17 |
+
from models.svc.transformer.conformer import Conformer
|
18 |
+
|
19 |
+
|
20 |
+
class TransformerInference(SVCInference):
|
21 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
22 |
+
SVCInference.__init__(self, args, cfg, infer_type)
|
23 |
+
|
24 |
+
def _build_model(self):
|
25 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
26 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
27 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
28 |
+
if self.cfg.model.transformer.type == "transformer":
|
29 |
+
self.acoustic_mapper = Transformer(self.cfg.model.transformer)
|
30 |
+
elif self.cfg.model.transformer.type == "conformer":
|
31 |
+
self.acoustic_mapper = Conformer(self.cfg.model.transformer)
|
32 |
+
else:
|
33 |
+
raise NotImplementedError
|
34 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
35 |
+
return model
|
36 |
+
|
37 |
+
def _inference_each_batch(self, batch_data):
|
38 |
+
device = self.accelerator.device
|
39 |
+
for k, v in batch_data.items():
|
40 |
+
batch_data[k] = v.to(device)
|
41 |
+
|
42 |
+
condition = self.condition_encoder(batch_data)
|
43 |
+
y_pred = self.acoustic_mapper(condition, batch_data["mask"])
|
44 |
+
|
45 |
+
return y_pred
|
models/svc/transformer/transformer_trainer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from models.svc.base import SVCTrainer
|
9 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
10 |
+
from models.svc.transformer.transformer import Transformer
|
11 |
+
from models.svc.transformer.conformer import Conformer
|
12 |
+
from utils.ssim import SSIM
|
13 |
+
|
14 |
+
|
15 |
+
class TransformerTrainer(SVCTrainer):
|
16 |
+
def __init__(self, args, cfg):
|
17 |
+
SVCTrainer.__init__(self, args, cfg)
|
18 |
+
self.ssim_loss = SSIM()
|
19 |
+
|
20 |
+
def _build_model(self):
|
21 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
22 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
23 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
24 |
+
if self.cfg.model.transformer.type == "transformer":
|
25 |
+
self.acoustic_mapper = Transformer(self.cfg.model.transformer)
|
26 |
+
elif self.cfg.model.transformer.type == "conformer":
|
27 |
+
self.acoustic_mapper = Conformer(self.cfg.model.transformer)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
30 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
31 |
+
return model
|
32 |
+
|
33 |
+
def _forward_step(self, batch):
|
34 |
+
total_loss = 0
|
35 |
+
device = self.accelerator.device
|
36 |
+
mel = batch["mel"]
|
37 |
+
mask = batch["mask"]
|
38 |
+
|
39 |
+
condition = self.condition_encoder(batch)
|
40 |
+
mel_pred = self.acoustic_mapper(condition, mask)
|
41 |
+
|
42 |
+
l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum(
|
43 |
+
batch["mask"]
|
44 |
+
)
|
45 |
+
self._check_nan(l1_loss, mel_pred, mel)
|
46 |
+
total_loss += l1_loss
|
47 |
+
ssim_loss = self.ssim_loss(mel_pred, mel)
|
48 |
+
ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"])
|
49 |
+
self._check_nan(ssim_loss, mel_pred, mel)
|
50 |
+
total_loss += ssim_loss
|
51 |
+
|
52 |
+
return total_loss
|
models/svc/vits/__init__.py
ADDED
File without changes
|
models/svc/vits/vits.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/models.py
|
7 |
+
import copy
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from utils.util import *
|
13 |
+
from utils.f0 import f0_to_coarse
|
14 |
+
|
15 |
+
from modules.transformer.attentions import Encoder
|
16 |
+
from models.tts.vits.vits import ResidualCouplingBlock, PosteriorEncoder
|
17 |
+
from models.vocoders.gan.generator.bigvgan import BigVGAN
|
18 |
+
from models.vocoders.gan.generator.hifigan import HiFiGAN
|
19 |
+
from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN
|
20 |
+
from models.vocoders.gan.generator.melgan import MelGAN
|
21 |
+
from models.vocoders.gan.generator.apnet import APNet
|
22 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
23 |
+
|
24 |
+
|
25 |
+
def slice_pitch_segments(x, ids_str, segment_size=4):
|
26 |
+
ret = torch.zeros_like(x[:, :segment_size])
|
27 |
+
for i in range(x.size(0)):
|
28 |
+
idx_str = ids_str[i]
|
29 |
+
idx_end = idx_str + segment_size
|
30 |
+
ret[i] = x[i, idx_str:idx_end]
|
31 |
+
return ret
|
32 |
+
|
33 |
+
|
34 |
+
def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
|
35 |
+
b, d, t = x.size()
|
36 |
+
if x_lengths is None:
|
37 |
+
x_lengths = t
|
38 |
+
ids_str_max = x_lengths - segment_size + 1
|
39 |
+
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
40 |
+
ret = slice_segments(x, ids_str, segment_size)
|
41 |
+
ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
|
42 |
+
return ret, ret_pitch, ids_str
|
43 |
+
|
44 |
+
|
45 |
+
class ContentEncoder(nn.Module):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
out_channels,
|
49 |
+
hidden_channels,
|
50 |
+
kernel_size,
|
51 |
+
n_layers,
|
52 |
+
gin_channels=0,
|
53 |
+
filter_channels=None,
|
54 |
+
n_heads=None,
|
55 |
+
p_dropout=None,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.out_channels = out_channels
|
59 |
+
self.hidden_channels = hidden_channels
|
60 |
+
self.kernel_size = kernel_size
|
61 |
+
self.n_layers = n_layers
|
62 |
+
self.gin_channels = gin_channels
|
63 |
+
|
64 |
+
self.f0_emb = nn.Embedding(256, hidden_channels)
|
65 |
+
|
66 |
+
self.enc_ = Encoder(
|
67 |
+
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
68 |
+
)
|
69 |
+
|
70 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
71 |
+
|
72 |
+
# condition_encoder ver.
|
73 |
+
def forward(self, x, x_mask, noice_scale=1):
|
74 |
+
x = self.enc_(x * x_mask, x_mask)
|
75 |
+
stats = self.proj(x) * x_mask
|
76 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
77 |
+
z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
|
78 |
+
|
79 |
+
return z, m, logs, x_mask
|
80 |
+
|
81 |
+
|
82 |
+
class SynthesizerTrn(nn.Module):
|
83 |
+
"""
|
84 |
+
Synthesizer for Training
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(self, spec_channels, segment_size, cfg):
|
88 |
+
super().__init__()
|
89 |
+
self.spec_channels = spec_channels
|
90 |
+
self.segment_size = segment_size
|
91 |
+
self.cfg = cfg
|
92 |
+
self.inter_channels = cfg.model.vits.inter_channels
|
93 |
+
self.hidden_channels = cfg.model.vits.hidden_channels
|
94 |
+
self.filter_channels = cfg.model.vits.filter_channels
|
95 |
+
self.n_heads = cfg.model.vits.n_heads
|
96 |
+
self.n_layers = cfg.model.vits.n_layers
|
97 |
+
self.kernel_size = cfg.model.vits.kernel_size
|
98 |
+
self.p_dropout = cfg.model.vits.p_dropout
|
99 |
+
self.ssl_dim = cfg.model.vits.ssl_dim
|
100 |
+
self.n_flow_layer = cfg.model.vits.n_flow_layer
|
101 |
+
self.gin_channels = cfg.model.vits.gin_channels
|
102 |
+
self.n_speakers = cfg.model.vits.n_speakers
|
103 |
+
|
104 |
+
# f0
|
105 |
+
self.n_bins = cfg.preprocess.pitch_bin
|
106 |
+
self.f0_min = cfg.preprocess.f0_min
|
107 |
+
self.f0_max = cfg.preprocess.f0_max
|
108 |
+
|
109 |
+
# TODO: sort out the config
|
110 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
111 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
112 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
113 |
+
|
114 |
+
self.emb_g = nn.Embedding(self.n_speakers, self.gin_channels)
|
115 |
+
|
116 |
+
self.enc_p = ContentEncoder(
|
117 |
+
self.inter_channels,
|
118 |
+
self.hidden_channels,
|
119 |
+
filter_channels=self.filter_channels,
|
120 |
+
n_heads=self.n_heads,
|
121 |
+
n_layers=self.n_layers,
|
122 |
+
kernel_size=self.kernel_size,
|
123 |
+
p_dropout=self.p_dropout,
|
124 |
+
)
|
125 |
+
|
126 |
+
assert cfg.model.generator in [
|
127 |
+
"bigvgan",
|
128 |
+
"hifigan",
|
129 |
+
"melgan",
|
130 |
+
"nsfhifigan",
|
131 |
+
"apnet",
|
132 |
+
]
|
133 |
+
self.dec_name = cfg.model.generator
|
134 |
+
temp_cfg = copy.deepcopy(cfg)
|
135 |
+
temp_cfg.preprocess.n_mel = self.inter_channels
|
136 |
+
if cfg.model.generator == "bigvgan":
|
137 |
+
temp_cfg.model.bigvgan = cfg.model.generator_config.bigvgan
|
138 |
+
self.dec = BigVGAN(temp_cfg)
|
139 |
+
elif cfg.model.generator == "hifigan":
|
140 |
+
temp_cfg.model.hifigan = cfg.model.generator_config.hifigan
|
141 |
+
self.dec = HiFiGAN(temp_cfg)
|
142 |
+
elif cfg.model.generator == "melgan":
|
143 |
+
temp_cfg.model.melgan = cfg.model.generator_config.melgan
|
144 |
+
self.dec = MelGAN(temp_cfg)
|
145 |
+
elif cfg.model.generator == "nsfhifigan":
|
146 |
+
temp_cfg.model.nsfhifigan = cfg.model.generator_config.nsfhifigan
|
147 |
+
self.dec = NSFHiFiGAN(temp_cfg) # TODO: nsf need f0
|
148 |
+
elif cfg.model.generator == "apnet":
|
149 |
+
temp_cfg.model.apnet = cfg.model.generator_config.apnet
|
150 |
+
self.dec = APNet(temp_cfg)
|
151 |
+
|
152 |
+
self.enc_q = PosteriorEncoder(
|
153 |
+
self.spec_channels,
|
154 |
+
self.inter_channels,
|
155 |
+
self.hidden_channels,
|
156 |
+
5,
|
157 |
+
1,
|
158 |
+
16,
|
159 |
+
gin_channels=self.gin_channels,
|
160 |
+
)
|
161 |
+
|
162 |
+
self.flow = ResidualCouplingBlock(
|
163 |
+
self.inter_channels,
|
164 |
+
self.hidden_channels,
|
165 |
+
5,
|
166 |
+
1,
|
167 |
+
self.n_flow_layer,
|
168 |
+
gin_channels=self.gin_channels,
|
169 |
+
)
|
170 |
+
|
171 |
+
def forward(self, data):
|
172 |
+
"""VitsSVC forward function.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
data (dict): condition data & audio data, including:
|
176 |
+
B: batch size, T: target length
|
177 |
+
{
|
178 |
+
"spk_id": [B, singer_table_size]
|
179 |
+
"target_len": [B]
|
180 |
+
"mask": [B, T, 1]
|
181 |
+
"mel": [B, T, n_mel]
|
182 |
+
"linear": [B, T, n_fft // 2 + 1]
|
183 |
+
"frame_pitch": [B, T]
|
184 |
+
"frame_uv": [B, T]
|
185 |
+
"audio": [B, audio_len]
|
186 |
+
"audio_len": [B]
|
187 |
+
"contentvec_feat": [B, T, contentvec_dim]
|
188 |
+
"whisper_feat": [B, T, whisper_dim]
|
189 |
+
...
|
190 |
+
}
|
191 |
+
"""
|
192 |
+
|
193 |
+
# TODO: elegantly handle the dimensions
|
194 |
+
c = data["contentvec_feat"].transpose(1, 2)
|
195 |
+
spec = data["linear"].transpose(1, 2)
|
196 |
+
|
197 |
+
g = data["spk_id"]
|
198 |
+
g = self.emb_g(g).transpose(1, 2)
|
199 |
+
|
200 |
+
c_lengths = data["target_len"]
|
201 |
+
spec_lengths = data["target_len"]
|
202 |
+
f0 = data["frame_pitch"]
|
203 |
+
|
204 |
+
x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
|
205 |
+
# condition_encoder ver.
|
206 |
+
x = self.condition_encoder(data).transpose(1, 2)
|
207 |
+
|
208 |
+
# prior encoder
|
209 |
+
z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask)
|
210 |
+
# posterior encoder
|
211 |
+
z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
|
212 |
+
|
213 |
+
# flow
|
214 |
+
z_p = self.flow(z, spec_mask, g=g)
|
215 |
+
z_slice, pitch_slice, ids_slice = rand_slice_segments_with_pitch(
|
216 |
+
z, f0, spec_lengths, self.segment_size
|
217 |
+
)
|
218 |
+
|
219 |
+
if self.dec_name == "nsfhifigan":
|
220 |
+
o = self.dec(z_slice, f0=f0.float())
|
221 |
+
elif self.dec_name == "apnet":
|
222 |
+
_, _, _, _, o = self.dec(z_slice)
|
223 |
+
else:
|
224 |
+
o = self.dec(z_slice)
|
225 |
+
|
226 |
+
outputs = {
|
227 |
+
"y_hat": o,
|
228 |
+
"ids_slice": ids_slice,
|
229 |
+
"x_mask": x_mask,
|
230 |
+
"z_mask": data["mask"].transpose(1, 2),
|
231 |
+
"z": z,
|
232 |
+
"z_p": z_p,
|
233 |
+
"m_p": m_p,
|
234 |
+
"logs_p": logs_p,
|
235 |
+
"m_q": m_q,
|
236 |
+
"logs_q": logs_q,
|
237 |
+
}
|
238 |
+
return outputs
|
239 |
+
|
240 |
+
@torch.no_grad()
|
241 |
+
def infer(self, data, noise_scale=0.35, seed=52468):
|
242 |
+
# c, f0, uv, g
|
243 |
+
c = data["contentvec_feat"].transpose(1, 2)
|
244 |
+
f0 = data["frame_pitch"]
|
245 |
+
g = data["spk_id"]
|
246 |
+
|
247 |
+
if c.device == torch.device("cuda"):
|
248 |
+
torch.cuda.manual_seed_all(seed)
|
249 |
+
else:
|
250 |
+
torch.manual_seed(seed)
|
251 |
+
|
252 |
+
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
253 |
+
|
254 |
+
if g.dim() == 1:
|
255 |
+
g = g.unsqueeze(0)
|
256 |
+
g = self.emb_g(g).transpose(1, 2)
|
257 |
+
|
258 |
+
x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
|
259 |
+
# condition_encoder ver.
|
260 |
+
x = self.condition_encoder(data).transpose(1, 2)
|
261 |
+
|
262 |
+
z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, noice_scale=noise_scale)
|
263 |
+
z = self.flow(z_p, c_mask, g=g, reverse=True)
|
264 |
+
|
265 |
+
if self.dec_name == "nsfhifigan":
|
266 |
+
o = self.dec(z * c_mask, f0=f0)
|
267 |
+
elif self.dec_name == "apnet":
|
268 |
+
_, _, _, _, o = self.dec(z * c_mask)
|
269 |
+
else:
|
270 |
+
o = self.dec(z * c_mask)
|
271 |
+
return o, f0
|
models/svc/vits/vits_inference.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from models.svc.base import SVCInference
|
14 |
+
from models.svc.vits.vits import SynthesizerTrn
|
15 |
+
|
16 |
+
from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator
|
17 |
+
from utils.io import save_audio
|
18 |
+
from utils.audio_slicer import is_silence
|
19 |
+
|
20 |
+
|
21 |
+
class VitsInference(SVCInference):
|
22 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
23 |
+
SVCInference.__init__(self, args, cfg)
|
24 |
+
|
25 |
+
def _build_model(self):
|
26 |
+
net_g = SynthesizerTrn(
|
27 |
+
self.cfg.preprocess.n_fft // 2 + 1,
|
28 |
+
self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
|
29 |
+
self.cfg,
|
30 |
+
)
|
31 |
+
self.model = net_g
|
32 |
+
return net_g
|
33 |
+
|
34 |
+
def build_save_dir(self, dataset, speaker):
|
35 |
+
save_dir = os.path.join(
|
36 |
+
self.args.output_dir,
|
37 |
+
"svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
|
38 |
+
)
|
39 |
+
if dataset is not None:
|
40 |
+
save_dir = os.path.join(save_dir, "data_{}".format(dataset))
|
41 |
+
if speaker != -1:
|
42 |
+
save_dir = os.path.join(
|
43 |
+
save_dir,
|
44 |
+
"spk_{}".format(speaker),
|
45 |
+
)
|
46 |
+
os.makedirs(save_dir, exist_ok=True)
|
47 |
+
print("Saving to ", save_dir)
|
48 |
+
return save_dir
|
49 |
+
|
50 |
+
@torch.inference_mode()
|
51 |
+
def inference(self):
|
52 |
+
res = []
|
53 |
+
for i, batch in enumerate(self.test_dataloader):
|
54 |
+
pred_audio_list = self._inference_each_batch(batch)
|
55 |
+
for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
|
56 |
+
uid = it["Uid"]
|
57 |
+
file = os.path.join(self.args.output_dir, f"{uid}.wav")
|
58 |
+
|
59 |
+
wav = wav.numpy(force=True)
|
60 |
+
save_audio(
|
61 |
+
file,
|
62 |
+
wav,
|
63 |
+
self.cfg.preprocess.sample_rate,
|
64 |
+
add_silence=False,
|
65 |
+
turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
|
66 |
+
)
|
67 |
+
res.append(file)
|
68 |
+
return res
|
69 |
+
|
70 |
+
def _inference_each_batch(self, batch_data, noise_scale=0.667):
|
71 |
+
device = self.accelerator.device
|
72 |
+
pred_res = []
|
73 |
+
self.model.eval()
|
74 |
+
with torch.no_grad():
|
75 |
+
# Put the data to device
|
76 |
+
# device = self.accelerator.device
|
77 |
+
for k, v in batch_data.items():
|
78 |
+
batch_data[k] = v.to(device)
|
79 |
+
|
80 |
+
audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)
|
81 |
+
|
82 |
+
pred_res.extend(audios)
|
83 |
+
|
84 |
+
return pred_res
|
models/svc/vits/vits_trainer.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim.lr_scheduler import ExponentialLR
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
# from models.svc.base import SVCTrainer
|
11 |
+
from models.svc.base.svc_dataset import SVCCollator, SVCDataset
|
12 |
+
from models.svc.vits.vits import *
|
13 |
+
from models.tts.base import TTSTrainer
|
14 |
+
|
15 |
+
from utils.mel import mel_spectrogram_torch
|
16 |
+
import json
|
17 |
+
|
18 |
+
from models.vocoders.gan.discriminator.mpd import (
|
19 |
+
MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class VitsSVCTrainer(TTSTrainer):
|
24 |
+
def __init__(self, args, cfg):
|
25 |
+
self.args = args
|
26 |
+
self.cfg = cfg
|
27 |
+
self._init_accelerator()
|
28 |
+
# Only for SVC tasks
|
29 |
+
with self.accelerator.main_process_first():
|
30 |
+
self.singers = self._build_singer_lut()
|
31 |
+
TTSTrainer.__init__(self, args, cfg)
|
32 |
+
|
33 |
+
def _build_model(self):
|
34 |
+
net_g = SynthesizerTrn(
|
35 |
+
self.cfg.preprocess.n_fft // 2 + 1,
|
36 |
+
self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
|
37 |
+
# directly use cfg
|
38 |
+
self.cfg,
|
39 |
+
)
|
40 |
+
net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
|
41 |
+
model = {"generator": net_g, "discriminator": net_d}
|
42 |
+
|
43 |
+
return model
|
44 |
+
|
45 |
+
def _build_dataset(self):
|
46 |
+
return SVCDataset, SVCCollator
|
47 |
+
|
48 |
+
def _build_optimizer(self):
|
49 |
+
optimizer_g = torch.optim.AdamW(
|
50 |
+
self.model["generator"].parameters(),
|
51 |
+
self.cfg.train.learning_rate,
|
52 |
+
betas=self.cfg.train.AdamW.betas,
|
53 |
+
eps=self.cfg.train.AdamW.eps,
|
54 |
+
)
|
55 |
+
optimizer_d = torch.optim.AdamW(
|
56 |
+
self.model["discriminator"].parameters(),
|
57 |
+
self.cfg.train.learning_rate,
|
58 |
+
betas=self.cfg.train.AdamW.betas,
|
59 |
+
eps=self.cfg.train.AdamW.eps,
|
60 |
+
)
|
61 |
+
optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
|
62 |
+
|
63 |
+
return optimizer
|
64 |
+
|
65 |
+
def _build_scheduler(self):
|
66 |
+
scheduler_g = ExponentialLR(
|
67 |
+
self.optimizer["optimizer_g"],
|
68 |
+
gamma=self.cfg.train.lr_decay,
|
69 |
+
last_epoch=self.epoch - 1,
|
70 |
+
)
|
71 |
+
scheduler_d = ExponentialLR(
|
72 |
+
self.optimizer["optimizer_d"],
|
73 |
+
gamma=self.cfg.train.lr_decay,
|
74 |
+
last_epoch=self.epoch - 1,
|
75 |
+
)
|
76 |
+
|
77 |
+
scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
|
78 |
+
return scheduler
|
79 |
+
|
80 |
+
def _build_criterion(self):
|
81 |
+
class GeneratorLoss(nn.Module):
|
82 |
+
def __init__(self, cfg):
|
83 |
+
super(GeneratorLoss, self).__init__()
|
84 |
+
self.cfg = cfg
|
85 |
+
self.l1_loss = nn.L1Loss()
|
86 |
+
|
87 |
+
def generator_loss(self, disc_outputs):
|
88 |
+
loss = 0
|
89 |
+
gen_losses = []
|
90 |
+
for dg in disc_outputs:
|
91 |
+
dg = dg.float()
|
92 |
+
l = torch.mean((1 - dg) ** 2)
|
93 |
+
gen_losses.append(l)
|
94 |
+
loss += l
|
95 |
+
|
96 |
+
return loss, gen_losses
|
97 |
+
|
98 |
+
def feature_loss(self, fmap_r, fmap_g):
|
99 |
+
loss = 0
|
100 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
101 |
+
for rl, gl in zip(dr, dg):
|
102 |
+
rl = rl.float().detach()
|
103 |
+
gl = gl.float()
|
104 |
+
loss += torch.mean(torch.abs(rl - gl))
|
105 |
+
|
106 |
+
return loss * 2
|
107 |
+
|
108 |
+
def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
|
109 |
+
"""
|
110 |
+
z_p, logs_q: [b, h, t_t]
|
111 |
+
m_p, logs_p: [b, h, t_t]
|
112 |
+
"""
|
113 |
+
z_p = z_p.float()
|
114 |
+
logs_q = logs_q.float()
|
115 |
+
m_p = m_p.float()
|
116 |
+
logs_p = logs_p.float()
|
117 |
+
z_mask = z_mask.float()
|
118 |
+
|
119 |
+
kl = logs_p - logs_q - 0.5
|
120 |
+
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
121 |
+
kl = torch.sum(kl * z_mask)
|
122 |
+
l = kl / torch.sum(z_mask)
|
123 |
+
return l
|
124 |
+
|
125 |
+
def forward(
|
126 |
+
self,
|
127 |
+
outputs_g,
|
128 |
+
outputs_d,
|
129 |
+
y_mel,
|
130 |
+
y_hat_mel,
|
131 |
+
):
|
132 |
+
loss_g = {}
|
133 |
+
|
134 |
+
# mel loss
|
135 |
+
loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
|
136 |
+
loss_g["loss_mel"] = loss_mel
|
137 |
+
|
138 |
+
# kl loss
|
139 |
+
loss_kl = (
|
140 |
+
self.kl_loss(
|
141 |
+
outputs_g["z_p"],
|
142 |
+
outputs_g["logs_q"],
|
143 |
+
outputs_g["m_p"],
|
144 |
+
outputs_g["logs_p"],
|
145 |
+
outputs_g["z_mask"],
|
146 |
+
)
|
147 |
+
* self.cfg.train.c_kl
|
148 |
+
)
|
149 |
+
loss_g["loss_kl"] = loss_kl
|
150 |
+
|
151 |
+
# feature loss
|
152 |
+
loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
|
153 |
+
loss_g["loss_fm"] = loss_fm
|
154 |
+
|
155 |
+
# gan loss
|
156 |
+
loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
|
157 |
+
loss_g["loss_gen"] = loss_gen
|
158 |
+
loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen
|
159 |
+
|
160 |
+
return loss_g
|
161 |
+
|
162 |
+
class DiscriminatorLoss(nn.Module):
|
163 |
+
def __init__(self, cfg):
|
164 |
+
super(DiscriminatorLoss, self).__init__()
|
165 |
+
self.cfg = cfg
|
166 |
+
self.l1Loss = torch.nn.L1Loss(reduction="mean")
|
167 |
+
|
168 |
+
def __call__(self, disc_real_outputs, disc_generated_outputs):
|
169 |
+
loss_d = {}
|
170 |
+
|
171 |
+
loss = 0
|
172 |
+
r_losses = []
|
173 |
+
g_losses = []
|
174 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
175 |
+
dr = dr.float()
|
176 |
+
dg = dg.float()
|
177 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
178 |
+
g_loss = torch.mean(dg**2)
|
179 |
+
loss += r_loss + g_loss
|
180 |
+
r_losses.append(r_loss.item())
|
181 |
+
g_losses.append(g_loss.item())
|
182 |
+
|
183 |
+
loss_d["loss_disc_all"] = loss
|
184 |
+
|
185 |
+
return loss_d
|
186 |
+
|
187 |
+
criterion = {
|
188 |
+
"generator": GeneratorLoss(self.cfg),
|
189 |
+
"discriminator": DiscriminatorLoss(self.cfg),
|
190 |
+
}
|
191 |
+
return criterion
|
192 |
+
|
193 |
+
# Keep legacy unchanged
|
194 |
+
def write_summary(
|
195 |
+
self,
|
196 |
+
losses,
|
197 |
+
stats,
|
198 |
+
images={},
|
199 |
+
audios={},
|
200 |
+
audio_sampling_rate=24000,
|
201 |
+
tag="train",
|
202 |
+
):
|
203 |
+
for key, value in losses.items():
|
204 |
+
self.sw.add_scalar(tag + "/" + key, value, self.step)
|
205 |
+
self.sw.add_scalar(
|
206 |
+
"learning_rate",
|
207 |
+
self.optimizer["optimizer_g"].param_groups[0]["lr"],
|
208 |
+
self.step,
|
209 |
+
)
|
210 |
+
|
211 |
+
if len(images) != 0:
|
212 |
+
for key, value in images.items():
|
213 |
+
self.sw.add_image(key, value, self.global_step, batchformats="HWC")
|
214 |
+
if len(audios) != 0:
|
215 |
+
for key, value in audios.items():
|
216 |
+
self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
|
217 |
+
|
218 |
+
def write_valid_summary(
|
219 |
+
self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
|
220 |
+
):
|
221 |
+
for key, value in losses.items():
|
222 |
+
self.sw.add_scalar(tag + "/" + key, value, self.step)
|
223 |
+
|
224 |
+
if len(images) != 0:
|
225 |
+
for key, value in images.items():
|
226 |
+
self.sw.add_image(key, value, self.global_step, batchformats="HWC")
|
227 |
+
if len(audios) != 0:
|
228 |
+
for key, value in audios.items():
|
229 |
+
self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
|
230 |
+
|
231 |
+
def _get_state_dict(self):
|
232 |
+
state_dict = {
|
233 |
+
"generator": self.model["generator"].state_dict(),
|
234 |
+
"discriminator": self.model["discriminator"].state_dict(),
|
235 |
+
"optimizer_g": self.optimizer["optimizer_g"].state_dict(),
|
236 |
+
"optimizer_d": self.optimizer["optimizer_d"].state_dict(),
|
237 |
+
"scheduler_g": self.scheduler["scheduler_g"].state_dict(),
|
238 |
+
"scheduler_d": self.scheduler["scheduler_d"].state_dict(),
|
239 |
+
"step": self.step,
|
240 |
+
"epoch": self.epoch,
|
241 |
+
"batch_size": self.cfg.train.batch_size,
|
242 |
+
}
|
243 |
+
return state_dict
|
244 |
+
|
245 |
+
def get_state_dict(self):
|
246 |
+
state_dict = {
|
247 |
+
"generator": self.model["generator"].state_dict(),
|
248 |
+
"discriminator": self.model["discriminator"].state_dict(),
|
249 |
+
"optimizer_g": self.optimizer["optimizer_g"].state_dict(),
|
250 |
+
"optimizer_d": self.optimizer["optimizer_d"].state_dict(),
|
251 |
+
"scheduler_g": self.scheduler["scheduler_g"].state_dict(),
|
252 |
+
"scheduler_d": self.scheduler["scheduler_d"].state_dict(),
|
253 |
+
"step": self.step,
|
254 |
+
"epoch": self.epoch,
|
255 |
+
"batch_size": self.cfg.train.batch_size,
|
256 |
+
}
|
257 |
+
return state_dict
|
258 |
+
|
259 |
+
def load_model(self, checkpoint):
|
260 |
+
self.step = checkpoint["step"]
|
261 |
+
self.epoch = checkpoint["epoch"]
|
262 |
+
self.model["generator"].load_state_dict(checkpoint["generator"])
|
263 |
+
self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
|
264 |
+
self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
|
265 |
+
self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
|
266 |
+
self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
|
267 |
+
self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])
|
268 |
+
|
269 |
+
@torch.inference_mode()
|
270 |
+
def _valid_step(self, batch):
|
271 |
+
r"""Testing forward step. Should return average loss of a sample over
|
272 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
273 |
+
See ``_test_epoch`` for usage.
|
274 |
+
"""
|
275 |
+
|
276 |
+
valid_losses = {}
|
277 |
+
total_loss = 0
|
278 |
+
valid_stats = {}
|
279 |
+
|
280 |
+
# Discriminator
|
281 |
+
# Generator output
|
282 |
+
outputs_g = self.model["generator"](batch)
|
283 |
+
|
284 |
+
y_mel = slice_segments(
|
285 |
+
batch["mel"].transpose(1, 2),
|
286 |
+
outputs_g["ids_slice"],
|
287 |
+
self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
|
288 |
+
)
|
289 |
+
y_hat_mel = mel_spectrogram_torch(
|
290 |
+
outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
|
291 |
+
)
|
292 |
+
y = slice_segments(
|
293 |
+
batch["audio"].unsqueeze(1),
|
294 |
+
outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
|
295 |
+
self.cfg.preprocess.segment_size,
|
296 |
+
)
|
297 |
+
|
298 |
+
# Discriminator output
|
299 |
+
outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
|
300 |
+
## Discriminator loss
|
301 |
+
loss_d = self.criterion["discriminator"](
|
302 |
+
outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
|
303 |
+
)
|
304 |
+
valid_losses.update(loss_d)
|
305 |
+
|
306 |
+
## Generator
|
307 |
+
outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
|
308 |
+
loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
|
309 |
+
valid_losses.update(loss_g)
|
310 |
+
|
311 |
+
for item in valid_losses:
|
312 |
+
valid_losses[item] = valid_losses[item].item()
|
313 |
+
|
314 |
+
total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
|
315 |
+
|
316 |
+
return (
|
317 |
+
total_loss.item(),
|
318 |
+
valid_losses,
|
319 |
+
valid_stats,
|
320 |
+
)
|
321 |
+
|
322 |
+
def _train_step(self, batch):
|
323 |
+
r"""Forward step for training and inference. This function is called
|
324 |
+
in ``_train_step`` & ``_test_step`` function.
|
325 |
+
"""
|
326 |
+
|
327 |
+
train_losses = {}
|
328 |
+
total_loss = 0
|
329 |
+
training_stats = {}
|
330 |
+
|
331 |
+
## Train Discriminator
|
332 |
+
# Generator output
|
333 |
+
outputs_g = self.model["generator"](batch)
|
334 |
+
|
335 |
+
y_mel = slice_segments(
|
336 |
+
batch["mel"].transpose(1, 2),
|
337 |
+
outputs_g["ids_slice"],
|
338 |
+
self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
|
339 |
+
)
|
340 |
+
y_hat_mel = mel_spectrogram_torch(
|
341 |
+
outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
|
342 |
+
)
|
343 |
+
|
344 |
+
y = slice_segments(
|
345 |
+
# [1, 168418] -> [1, 1, 168418]
|
346 |
+
batch["audio"].unsqueeze(1),
|
347 |
+
outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
|
348 |
+
self.cfg.preprocess.segment_size,
|
349 |
+
)
|
350 |
+
|
351 |
+
# Discriminator output
|
352 |
+
outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
|
353 |
+
# Discriminator loss
|
354 |
+
loss_d = self.criterion["discriminator"](
|
355 |
+
outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
|
356 |
+
)
|
357 |
+
train_losses.update(loss_d)
|
358 |
+
|
359 |
+
# BP and Grad Updated
|
360 |
+
self.optimizer["optimizer_d"].zero_grad()
|
361 |
+
self.accelerator.backward(loss_d["loss_disc_all"])
|
362 |
+
self.optimizer["optimizer_d"].step()
|
363 |
+
|
364 |
+
## Train Generator
|
365 |
+
outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
|
366 |
+
loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
|
367 |
+
train_losses.update(loss_g)
|
368 |
+
|
369 |
+
# BP and Grad Updated
|
370 |
+
self.optimizer["optimizer_g"].zero_grad()
|
371 |
+
self.accelerator.backward(loss_g["loss_gen_all"])
|
372 |
+
self.optimizer["optimizer_g"].step()
|
373 |
+
|
374 |
+
for item in train_losses:
|
375 |
+
train_losses[item] = train_losses[item].item()
|
376 |
+
|
377 |
+
total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
|
378 |
+
|
379 |
+
return (
|
380 |
+
total_loss.item(),
|
381 |
+
train_losses,
|
382 |
+
training_stats,
|
383 |
+
)
|
384 |
+
|
385 |
+
def _train_epoch(self):
|
386 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
387 |
+
one epoch. See ``train_loop`` for usage.
|
388 |
+
"""
|
389 |
+
epoch_sum_loss: float = 0.0
|
390 |
+
epoch_losses: dict = {}
|
391 |
+
epoch_step: int = 0
|
392 |
+
for batch in tqdm(
|
393 |
+
self.train_dataloader,
|
394 |
+
desc=f"Training Epoch {self.epoch}",
|
395 |
+
unit="batch",
|
396 |
+
colour="GREEN",
|
397 |
+
leave=False,
|
398 |
+
dynamic_ncols=True,
|
399 |
+
smoothing=0.04,
|
400 |
+
disable=not self.accelerator.is_main_process,
|
401 |
+
):
|
402 |
+
# Do training step and BP
|
403 |
+
with self.accelerator.accumulate(self.model):
|
404 |
+
total_loss, train_losses, training_stats = self._train_step(batch)
|
405 |
+
self.batch_count += 1
|
406 |
+
|
407 |
+
# Update info for each step
|
408 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
409 |
+
epoch_sum_loss += total_loss
|
410 |
+
for key, value in train_losses.items():
|
411 |
+
if key not in epoch_losses.keys():
|
412 |
+
epoch_losses[key] = value
|
413 |
+
else:
|
414 |
+
epoch_losses[key] += value
|
415 |
+
|
416 |
+
self.accelerator.log(
|
417 |
+
{
|
418 |
+
"Step/Generator Loss": train_losses["loss_gen_all"],
|
419 |
+
"Step/Discriminator Loss": train_losses["loss_disc_all"],
|
420 |
+
"Step/Generator Learning Rate": self.optimizer[
|
421 |
+
"optimizer_d"
|
422 |
+
].param_groups[0]["lr"],
|
423 |
+
"Step/Discriminator Learning Rate": self.optimizer[
|
424 |
+
"optimizer_g"
|
425 |
+
].param_groups[0]["lr"],
|
426 |
+
},
|
427 |
+
step=self.step,
|
428 |
+
)
|
429 |
+
self.step += 1
|
430 |
+
epoch_step += 1
|
431 |
+
|
432 |
+
self.accelerator.wait_for_everyone()
|
433 |
+
|
434 |
+
epoch_sum_loss = (
|
435 |
+
epoch_sum_loss
|
436 |
+
/ len(self.train_dataloader)
|
437 |
+
* self.cfg.train.gradient_accumulation_step
|
438 |
+
)
|
439 |
+
|
440 |
+
for key in epoch_losses.keys():
|
441 |
+
epoch_losses[key] = (
|
442 |
+
epoch_losses[key]
|
443 |
+
/ len(self.train_dataloader)
|
444 |
+
* self.cfg.train.gradient_accumulation_step
|
445 |
+
)
|
446 |
+
|
447 |
+
return epoch_sum_loss, epoch_losses
|
448 |
+
|
449 |
+
def _build_singer_lut(self):
|
450 |
+
resumed_singer_path = None
|
451 |
+
if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
|
452 |
+
resumed_singer_path = os.path.join(
|
453 |
+
self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
|
454 |
+
)
|
455 |
+
if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
|
456 |
+
resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
457 |
+
|
458 |
+
if resumed_singer_path:
|
459 |
+
with open(resumed_singer_path, "r") as f:
|
460 |
+
singers = json.load(f)
|
461 |
+
else:
|
462 |
+
singers = dict()
|
463 |
+
|
464 |
+
for dataset in self.cfg.dataset:
|
465 |
+
singer_lut_path = os.path.join(
|
466 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
467 |
+
)
|
468 |
+
with open(singer_lut_path, "r") as singer_lut_path:
|
469 |
+
singer_lut = json.load(singer_lut_path)
|
470 |
+
for singer in singer_lut.keys():
|
471 |
+
if singer not in singers:
|
472 |
+
singers[singer] = len(singers)
|
473 |
+
|
474 |
+
with open(
|
475 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
|
476 |
+
) as singer_file:
|
477 |
+
json.dump(singers, singer_file, indent=4, ensure_ascii=False)
|
478 |
+
print(
|
479 |
+
"singers have been dumped to {}".format(
|
480 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
481 |
+
)
|
482 |
+
)
|
483 |
+
return singers
|
models/tta/autoencoder/__init__.py
ADDED
File without changes
|
models/tta/autoencoder/autoencoder.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from modules.distributions.distributions import DiagonalGaussianDistribution
|
11 |
+
|
12 |
+
|
13 |
+
def nonlinearity(x):
|
14 |
+
# swish
|
15 |
+
return x * torch.sigmoid(x)
|
16 |
+
|
17 |
+
|
18 |
+
def Normalize(in_channels):
|
19 |
+
return torch.nn.GroupNorm(
|
20 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class Upsample2d(nn.Module):
|
25 |
+
def __init__(self, in_channels, with_conv):
|
26 |
+
super().__init__()
|
27 |
+
self.with_conv = with_conv
|
28 |
+
if self.with_conv:
|
29 |
+
self.conv = torch.nn.Conv2d(
|
30 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
31 |
+
)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
35 |
+
if self.with_conv:
|
36 |
+
x = self.conv(x)
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class Upsample1d(Upsample2d):
|
41 |
+
def __init__(self, in_channels, with_conv):
|
42 |
+
super().__init__(in_channels, with_conv)
|
43 |
+
if self.with_conv:
|
44 |
+
self.conv = torch.nn.Conv1d(
|
45 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
class Downsample2d(nn.Module):
|
50 |
+
def __init__(self, in_channels, with_conv):
|
51 |
+
super().__init__()
|
52 |
+
self.with_conv = with_conv
|
53 |
+
if self.with_conv:
|
54 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
55 |
+
self.conv = torch.nn.Conv2d(
|
56 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
57 |
+
)
|
58 |
+
self.pad = (0, 1, 0, 1)
|
59 |
+
else:
|
60 |
+
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
if self.with_conv: # bp: check self.avgpool and self.pad
|
64 |
+
x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0)
|
65 |
+
x = self.conv(x)
|
66 |
+
else:
|
67 |
+
x = self.avg_pool(x)
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
class Downsample1d(Downsample2d):
|
72 |
+
def __init__(self, in_channels, with_conv):
|
73 |
+
super().__init__(in_channels, with_conv)
|
74 |
+
if self.with_conv:
|
75 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
76 |
+
# TODO: can we replace it just with conv2d with padding 1?
|
77 |
+
self.conv = torch.nn.Conv1d(
|
78 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
79 |
+
)
|
80 |
+
self.pad = (1, 1)
|
81 |
+
else:
|
82 |
+
self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
|
83 |
+
|
84 |
+
|
85 |
+
class ResnetBlock(nn.Module):
|
86 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
|
87 |
+
super().__init__()
|
88 |
+
self.in_channels = in_channels
|
89 |
+
out_channels = in_channels if out_channels is None else out_channels
|
90 |
+
self.out_channels = out_channels
|
91 |
+
self.use_conv_shortcut = conv_shortcut
|
92 |
+
|
93 |
+
self.norm1 = Normalize(in_channels)
|
94 |
+
self.conv1 = torch.nn.Conv2d(
|
95 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
96 |
+
)
|
97 |
+
|
98 |
+
self.norm2 = Normalize(out_channels)
|
99 |
+
self.dropout = torch.nn.Dropout(dropout)
|
100 |
+
self.conv2 = torch.nn.Conv2d(
|
101 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
102 |
+
)
|
103 |
+
if self.in_channels != self.out_channels:
|
104 |
+
if self.use_conv_shortcut:
|
105 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
106 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
110 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
111 |
+
)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
h = x
|
115 |
+
h = self.norm1(h)
|
116 |
+
h = nonlinearity(h)
|
117 |
+
h = self.conv1(h)
|
118 |
+
|
119 |
+
h = self.norm2(h)
|
120 |
+
h = nonlinearity(h)
|
121 |
+
h = self.dropout(h)
|
122 |
+
h = self.conv2(h)
|
123 |
+
|
124 |
+
if self.in_channels != self.out_channels:
|
125 |
+
if self.use_conv_shortcut:
|
126 |
+
x = self.conv_shortcut(x)
|
127 |
+
else:
|
128 |
+
x = self.nin_shortcut(x)
|
129 |
+
|
130 |
+
return x + h
|
131 |
+
|
132 |
+
|
133 |
+
class ResnetBlock1d(ResnetBlock):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
*,
|
137 |
+
in_channels,
|
138 |
+
out_channels=None,
|
139 |
+
conv_shortcut=False,
|
140 |
+
dropout,
|
141 |
+
temb_channels=512
|
142 |
+
):
|
143 |
+
super().__init__(
|
144 |
+
in_channels=in_channels,
|
145 |
+
out_channels=out_channels,
|
146 |
+
conv_shortcut=conv_shortcut,
|
147 |
+
dropout=dropout,
|
148 |
+
)
|
149 |
+
|
150 |
+
self.conv1 = torch.nn.Conv1d(
|
151 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
152 |
+
)
|
153 |
+
self.conv2 = torch.nn.Conv1d(
|
154 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
155 |
+
)
|
156 |
+
if self.in_channels != self.out_channels:
|
157 |
+
if self.use_conv_shortcut:
|
158 |
+
self.conv_shortcut = torch.nn.Conv1d(
|
159 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
self.nin_shortcut = torch.nn.Conv1d(
|
163 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
164 |
+
)
|
165 |
+
|
166 |
+
|
167 |
+
class Encoder2d(nn.Module):
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
*,
|
171 |
+
ch,
|
172 |
+
ch_mult=(1, 2, 4, 8),
|
173 |
+
num_res_blocks,
|
174 |
+
dropout=0.0,
|
175 |
+
resamp_with_conv=True,
|
176 |
+
in_channels,
|
177 |
+
z_channels,
|
178 |
+
double_z=True,
|
179 |
+
**ignore_kwargs
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
self.ch = ch
|
183 |
+
self.num_resolutions = len(ch_mult)
|
184 |
+
self.num_res_blocks = num_res_blocks
|
185 |
+
self.in_channels = in_channels
|
186 |
+
|
187 |
+
# downsampling
|
188 |
+
self.conv_in = torch.nn.Conv2d(
|
189 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
190 |
+
)
|
191 |
+
|
192 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
193 |
+
self.down = nn.ModuleList()
|
194 |
+
for i_level in range(self.num_resolutions):
|
195 |
+
block = nn.ModuleList()
|
196 |
+
block_in = ch * in_ch_mult[i_level]
|
197 |
+
block_out = ch * ch_mult[i_level]
|
198 |
+
for i_block in range(self.num_res_blocks):
|
199 |
+
block.append(
|
200 |
+
ResnetBlock(
|
201 |
+
in_channels=block_in, out_channels=block_out, dropout=dropout
|
202 |
+
)
|
203 |
+
)
|
204 |
+
block_in = block_out
|
205 |
+
down = nn.Module()
|
206 |
+
down.block = block
|
207 |
+
if i_level != self.num_resolutions - 1:
|
208 |
+
down.downsample = Downsample2d(block_in, resamp_with_conv)
|
209 |
+
self.down.append(down)
|
210 |
+
|
211 |
+
# middle
|
212 |
+
self.mid = nn.Module()
|
213 |
+
self.mid.block_1 = ResnetBlock(
|
214 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout
|
215 |
+
)
|
216 |
+
self.mid.block_2 = ResnetBlock(
|
217 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout
|
218 |
+
)
|
219 |
+
|
220 |
+
# end
|
221 |
+
self.norm_out = Normalize(block_in)
|
222 |
+
self.conv_out = torch.nn.Conv2d(
|
223 |
+
block_in,
|
224 |
+
2 * z_channels if double_z else z_channels,
|
225 |
+
kernel_size=3,
|
226 |
+
stride=1,
|
227 |
+
padding=1,
|
228 |
+
)
|
229 |
+
|
230 |
+
def forward(self, x):
|
231 |
+
# downsampling
|
232 |
+
hs = [self.conv_in(x)]
|
233 |
+
for i_level in range(self.num_resolutions):
|
234 |
+
for i_block in range(self.num_res_blocks):
|
235 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
236 |
+
hs.append(h)
|
237 |
+
if i_level != self.num_resolutions - 1:
|
238 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
239 |
+
|
240 |
+
# middle
|
241 |
+
h = hs[-1]
|
242 |
+
h = self.mid.block_1(h)
|
243 |
+
h = self.mid.block_2(h)
|
244 |
+
|
245 |
+
# end
|
246 |
+
h = self.norm_out(h)
|
247 |
+
h = nonlinearity(h)
|
248 |
+
h = self.conv_out(h)
|
249 |
+
return h
|
250 |
+
|
251 |
+
|
252 |
+
# TODO: Encoder1d
|
253 |
+
class Encoder1d(Encoder2d):
|
254 |
+
...
|
255 |
+
|
256 |
+
|
257 |
+
class Decoder2d(nn.Module):
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
*,
|
261 |
+
ch,
|
262 |
+
out_ch,
|
263 |
+
ch_mult=(1, 2, 4, 8),
|
264 |
+
num_res_blocks,
|
265 |
+
dropout=0.0,
|
266 |
+
resamp_with_conv=True,
|
267 |
+
in_channels,
|
268 |
+
z_channels,
|
269 |
+
give_pre_end=False,
|
270 |
+
**ignorekwargs
|
271 |
+
):
|
272 |
+
super().__init__()
|
273 |
+
self.ch = ch
|
274 |
+
self.num_resolutions = len(ch_mult)
|
275 |
+
self.num_res_blocks = num_res_blocks
|
276 |
+
self.in_channels = in_channels
|
277 |
+
self.give_pre_end = give_pre_end
|
278 |
+
|
279 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
280 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
281 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
282 |
+
# self.z_shape = (1,z_channels,curr_res,curr_res)
|
283 |
+
# print("Working with z of shape {} = {} dimensions.".format(
|
284 |
+
# self.z_shape, np.prod(self.z_shape)))
|
285 |
+
|
286 |
+
# z to block_in
|
287 |
+
self.conv_in = torch.nn.Conv2d(
|
288 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
289 |
+
)
|
290 |
+
|
291 |
+
# middle
|
292 |
+
self.mid = nn.Module()
|
293 |
+
self.mid.block_1 = ResnetBlock(
|
294 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout
|
295 |
+
)
|
296 |
+
self.mid.block_2 = ResnetBlock(
|
297 |
+
in_channels=block_in, out_channels=block_in, dropout=dropout
|
298 |
+
)
|
299 |
+
|
300 |
+
# upsampling
|
301 |
+
self.up = nn.ModuleList()
|
302 |
+
for i_level in reversed(range(self.num_resolutions)):
|
303 |
+
block = nn.ModuleList()
|
304 |
+
attn = nn.ModuleList()
|
305 |
+
block_out = ch * ch_mult[i_level]
|
306 |
+
for i_block in range(self.num_res_blocks + 1):
|
307 |
+
block.append(
|
308 |
+
ResnetBlock(
|
309 |
+
in_channels=block_in, out_channels=block_out, dropout=dropout
|
310 |
+
)
|
311 |
+
)
|
312 |
+
block_in = block_out
|
313 |
+
up = nn.Module()
|
314 |
+
up.block = block
|
315 |
+
up.attn = attn
|
316 |
+
if i_level != 0:
|
317 |
+
up.upsample = Upsample2d(block_in, resamp_with_conv)
|
318 |
+
self.up.insert(0, up) # prepend to get consistent order
|
319 |
+
|
320 |
+
# end
|
321 |
+
self.norm_out = Normalize(block_in)
|
322 |
+
self.conv_out = torch.nn.Conv2d(
|
323 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
324 |
+
)
|
325 |
+
|
326 |
+
def forward(self, z):
|
327 |
+
self.last_z_shape = z.shape
|
328 |
+
|
329 |
+
# z to block_in
|
330 |
+
h = self.conv_in(z)
|
331 |
+
|
332 |
+
# middle
|
333 |
+
h = self.mid.block_1(h)
|
334 |
+
h = self.mid.block_2(h)
|
335 |
+
|
336 |
+
# upsampling
|
337 |
+
for i_level in reversed(range(self.num_resolutions)):
|
338 |
+
for i_block in range(self.num_res_blocks + 1):
|
339 |
+
h = self.up[i_level].block[i_block](h)
|
340 |
+
if i_level != 0:
|
341 |
+
h = self.up[i_level].upsample(h)
|
342 |
+
|
343 |
+
# end
|
344 |
+
if self.give_pre_end:
|
345 |
+
return h
|
346 |
+
|
347 |
+
h = self.norm_out(h)
|
348 |
+
h = nonlinearity(h)
|
349 |
+
h = self.conv_out(h)
|
350 |
+
return h
|
351 |
+
|
352 |
+
|
353 |
+
# TODO: decoder1d
|
354 |
+
class Decoder1d(Decoder2d):
|
355 |
+
...
|
356 |
+
|
357 |
+
|
358 |
+
class AutoencoderKL(nn.Module):
|
359 |
+
def __init__(self, cfg):
|
360 |
+
super().__init__()
|
361 |
+
self.cfg = cfg
|
362 |
+
self.encoder = Encoder2d(
|
363 |
+
ch=cfg.ch,
|
364 |
+
ch_mult=cfg.ch_mult,
|
365 |
+
num_res_blocks=cfg.num_res_blocks,
|
366 |
+
in_channels=cfg.in_channels,
|
367 |
+
z_channels=cfg.z_channels,
|
368 |
+
double_z=cfg.double_z,
|
369 |
+
)
|
370 |
+
self.decoder = Decoder2d(
|
371 |
+
ch=cfg.ch,
|
372 |
+
ch_mult=cfg.ch_mult,
|
373 |
+
num_res_blocks=cfg.num_res_blocks,
|
374 |
+
out_ch=cfg.out_ch,
|
375 |
+
z_channels=cfg.z_channels,
|
376 |
+
in_channels=None,
|
377 |
+
)
|
378 |
+
assert self.cfg.double_z
|
379 |
+
|
380 |
+
self.quant_conv = torch.nn.Conv2d(2 * cfg.z_channels, 2 * cfg.z_channels, 1)
|
381 |
+
self.post_quant_conv = torch.nn.Conv2d(cfg.z_channels, cfg.z_channels, 1)
|
382 |
+
self.embed_dim = cfg.z_channels
|
383 |
+
|
384 |
+
def encode(self, x):
|
385 |
+
h = self.encoder(x)
|
386 |
+
moments = self.quant_conv(h)
|
387 |
+
posterior = DiagonalGaussianDistribution(moments)
|
388 |
+
return posterior
|
389 |
+
|
390 |
+
def decode(self, z):
|
391 |
+
z = self.post_quant_conv(z)
|
392 |
+
dec = self.decoder(z)
|
393 |
+
return dec
|
394 |
+
|
395 |
+
def forward(self, input, sample_posterior=True):
|
396 |
+
posterior = self.encode(input)
|
397 |
+
if sample_posterior:
|
398 |
+
z = posterior.sample()
|
399 |
+
else:
|
400 |
+
z = posterior.mode()
|
401 |
+
dec = self.decode(z)
|
402 |
+
return dec, posterior
|
403 |
+
|
404 |
+
def get_last_layer(self):
|
405 |
+
return self.decoder.conv_out.weight
|
models/tta/autoencoder/autoencoder_dataset.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
from torch.nn.utils.rnn import pad_sequence
|
9 |
+
from utils.data_utils import *
|
10 |
+
from models.base.base_dataset import (
|
11 |
+
BaseCollator,
|
12 |
+
BaseDataset,
|
13 |
+
BaseTestDataset,
|
14 |
+
BaseTestCollator,
|
15 |
+
)
|
16 |
+
import librosa
|
17 |
+
|
18 |
+
|
19 |
+
class AutoencoderKLDataset(BaseDataset):
|
20 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
21 |
+
BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
|
22 |
+
|
23 |
+
cfg = self.cfg
|
24 |
+
|
25 |
+
# utt2melspec
|
26 |
+
if cfg.preprocess.use_melspec:
|
27 |
+
self.utt2melspec_path = {}
|
28 |
+
for utt_info in self.metadata:
|
29 |
+
dataset = utt_info["Dataset"]
|
30 |
+
uid = utt_info["Uid"]
|
31 |
+
utt = "{}_{}".format(dataset, uid)
|
32 |
+
|
33 |
+
self.utt2melspec_path[utt] = os.path.join(
|
34 |
+
cfg.preprocess.processed_dir,
|
35 |
+
dataset,
|
36 |
+
cfg.preprocess.melspec_dir,
|
37 |
+
uid + ".npy",
|
38 |
+
)
|
39 |
+
|
40 |
+
# utt2wav
|
41 |
+
if cfg.preprocess.use_wav:
|
42 |
+
self.utt2wav_path = {}
|
43 |
+
for utt_info in self.metadata:
|
44 |
+
dataset = utt_info["Dataset"]
|
45 |
+
uid = utt_info["Uid"]
|
46 |
+
utt = "{}_{}".format(dataset, uid)
|
47 |
+
|
48 |
+
self.utt2wav_path[utt] = os.path.join(
|
49 |
+
cfg.preprocess.processed_dir,
|
50 |
+
dataset,
|
51 |
+
cfg.preprocess.wav_dir,
|
52 |
+
uid + ".wav",
|
53 |
+
)
|
54 |
+
|
55 |
+
def __getitem__(self, index):
|
56 |
+
# melspec: (n_mels, T)
|
57 |
+
# wav: (T,)
|
58 |
+
|
59 |
+
single_feature = BaseDataset.__getitem__(self, index)
|
60 |
+
|
61 |
+
utt_info = self.metadata[index]
|
62 |
+
dataset = utt_info["Dataset"]
|
63 |
+
uid = utt_info["Uid"]
|
64 |
+
utt = "{}_{}".format(dataset, uid)
|
65 |
+
|
66 |
+
if self.cfg.preprocess.use_melspec:
|
67 |
+
single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
|
68 |
+
|
69 |
+
if self.cfg.preprocess.use_wav:
|
70 |
+
wav, sr = librosa.load(
|
71 |
+
self.utt2wav_path[utt], sr=16000
|
72 |
+
) # hard coding for 16KHz...
|
73 |
+
single_feature["wav"] = wav
|
74 |
+
|
75 |
+
return single_feature
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self.metadata)
|
79 |
+
|
80 |
+
def __len__(self):
|
81 |
+
return len(self.metadata)
|
82 |
+
|
83 |
+
|
84 |
+
class AutoencoderKLCollator(BaseCollator):
|
85 |
+
def __init__(self, cfg):
|
86 |
+
BaseCollator.__init__(self, cfg)
|
87 |
+
|
88 |
+
def __call__(self, batch):
|
89 |
+
# mel: (B, n_mels, T)
|
90 |
+
# wav (option): (B, T)
|
91 |
+
|
92 |
+
packed_batch_features = dict()
|
93 |
+
|
94 |
+
for key in batch[0].keys():
|
95 |
+
if key == "melspec":
|
96 |
+
packed_batch_features["melspec"] = torch.from_numpy(
|
97 |
+
np.array([b["melspec"][:, :624] for b in batch])
|
98 |
+
)
|
99 |
+
|
100 |
+
if key == "wav":
|
101 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
102 |
+
packed_batch_features[key] = pad_sequence(
|
103 |
+
values, batch_first=True, padding_value=0
|
104 |
+
)
|
105 |
+
|
106 |
+
return packed_batch_features
|
107 |
+
|
108 |
+
|
109 |
+
class AutoencoderKLTestDataset(BaseTestDataset):
|
110 |
+
...
|
111 |
+
|
112 |
+
|
113 |
+
class AutoencoderKLTestCollator(BaseTestCollator):
|
114 |
+
...
|
models/tta/autoencoder/autoencoder_loss.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import functools
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def hinge_d_loss(logits_real, logits_fake):
|
13 |
+
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
14 |
+
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
15 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
16 |
+
return d_loss
|
17 |
+
|
18 |
+
|
19 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
20 |
+
d_loss = 0.5 * (
|
21 |
+
torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
|
22 |
+
)
|
23 |
+
return d_loss
|
24 |
+
|
25 |
+
|
26 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
27 |
+
if global_step < threshold:
|
28 |
+
weight = value
|
29 |
+
return weight
|
30 |
+
|
31 |
+
|
32 |
+
class ActNorm(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self, num_features, logdet=False, affine=True, allow_reverse_init=False
|
35 |
+
):
|
36 |
+
assert affine
|
37 |
+
super().__init__()
|
38 |
+
self.logdet = logdet
|
39 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
40 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
41 |
+
self.allow_reverse_init = allow_reverse_init
|
42 |
+
|
43 |
+
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
|
44 |
+
|
45 |
+
def initialize(self, input):
|
46 |
+
with torch.no_grad():
|
47 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
48 |
+
mean = (
|
49 |
+
flatten.mean(1)
|
50 |
+
.unsqueeze(1)
|
51 |
+
.unsqueeze(2)
|
52 |
+
.unsqueeze(3)
|
53 |
+
.permute(1, 0, 2, 3)
|
54 |
+
)
|
55 |
+
std = (
|
56 |
+
flatten.std(1)
|
57 |
+
.unsqueeze(1)
|
58 |
+
.unsqueeze(2)
|
59 |
+
.unsqueeze(3)
|
60 |
+
.permute(1, 0, 2, 3)
|
61 |
+
)
|
62 |
+
|
63 |
+
self.loc.data.copy_(-mean)
|
64 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
65 |
+
|
66 |
+
def forward(self, input, reverse=False):
|
67 |
+
if reverse:
|
68 |
+
return self.reverse(input)
|
69 |
+
if len(input.shape) == 2:
|
70 |
+
input = input[:, :, None, None]
|
71 |
+
squeeze = True
|
72 |
+
else:
|
73 |
+
squeeze = False
|
74 |
+
|
75 |
+
_, _, height, width = input.shape
|
76 |
+
|
77 |
+
if self.training and self.initialized.item() == 0:
|
78 |
+
self.initialize(input)
|
79 |
+
self.initialized.fill_(1)
|
80 |
+
|
81 |
+
h = self.scale * (input + self.loc)
|
82 |
+
|
83 |
+
if squeeze:
|
84 |
+
h = h.squeeze(-1).squeeze(-1)
|
85 |
+
|
86 |
+
if self.logdet:
|
87 |
+
log_abs = torch.log(torch.abs(self.scale))
|
88 |
+
logdet = height * width * torch.sum(log_abs)
|
89 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
90 |
+
return h, logdet
|
91 |
+
|
92 |
+
return h
|
93 |
+
|
94 |
+
def reverse(self, output):
|
95 |
+
if self.training and self.initialized.item() == 0:
|
96 |
+
if not self.allow_reverse_init:
|
97 |
+
raise RuntimeError(
|
98 |
+
"Initializing ActNorm in reverse direction is "
|
99 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
self.initialize(output)
|
103 |
+
self.initialized.fill_(1)
|
104 |
+
|
105 |
+
if len(output.shape) == 2:
|
106 |
+
output = output[:, :, None, None]
|
107 |
+
squeeze = True
|
108 |
+
else:
|
109 |
+
squeeze = False
|
110 |
+
|
111 |
+
h = output / self.scale - self.loc
|
112 |
+
|
113 |
+
if squeeze:
|
114 |
+
h = h.squeeze(-1).squeeze(-1)
|
115 |
+
return h
|
116 |
+
|
117 |
+
|
118 |
+
def weights_init(m):
|
119 |
+
classname = m.__class__.__name__
|
120 |
+
if classname.find("Conv") != -1:
|
121 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
122 |
+
elif classname.find("BatchNorm") != -1:
|
123 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
124 |
+
nn.init.constant_(m.bias.data, 0)
|
125 |
+
|
126 |
+
|
127 |
+
class NLayerDiscriminator(nn.Module):
|
128 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
129 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
133 |
+
"""Construct a PatchGAN discriminator
|
134 |
+
Parameters:
|
135 |
+
input_nc (int) -- the number of channels in input images
|
136 |
+
ndf (int) -- the number of filters in the last conv layer
|
137 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
138 |
+
norm_layer -- normalization layer
|
139 |
+
"""
|
140 |
+
super(NLayerDiscriminator, self).__init__()
|
141 |
+
if not use_actnorm:
|
142 |
+
norm_layer = nn.BatchNorm2d
|
143 |
+
else:
|
144 |
+
norm_layer = ActNorm
|
145 |
+
if (
|
146 |
+
type(norm_layer) == functools.partial
|
147 |
+
): # no need to use bias as BatchNorm2d has affine parameters
|
148 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
149 |
+
else:
|
150 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
151 |
+
|
152 |
+
kw = 4
|
153 |
+
padw = 1
|
154 |
+
sequence = [
|
155 |
+
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
156 |
+
nn.LeakyReLU(0.2, True),
|
157 |
+
]
|
158 |
+
nf_mult = 1
|
159 |
+
nf_mult_prev = 1
|
160 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
161 |
+
nf_mult_prev = nf_mult
|
162 |
+
nf_mult = min(2**n, 8)
|
163 |
+
sequence += [
|
164 |
+
nn.Conv2d(
|
165 |
+
ndf * nf_mult_prev,
|
166 |
+
ndf * nf_mult,
|
167 |
+
kernel_size=kw,
|
168 |
+
stride=2,
|
169 |
+
padding=padw,
|
170 |
+
bias=use_bias,
|
171 |
+
),
|
172 |
+
norm_layer(ndf * nf_mult),
|
173 |
+
nn.LeakyReLU(0.2, True),
|
174 |
+
]
|
175 |
+
|
176 |
+
nf_mult_prev = nf_mult
|
177 |
+
nf_mult = min(2**n_layers, 8)
|
178 |
+
sequence += [
|
179 |
+
nn.Conv2d(
|
180 |
+
ndf * nf_mult_prev,
|
181 |
+
ndf * nf_mult,
|
182 |
+
kernel_size=kw,
|
183 |
+
stride=1,
|
184 |
+
padding=padw,
|
185 |
+
bias=use_bias,
|
186 |
+
),
|
187 |
+
norm_layer(ndf * nf_mult),
|
188 |
+
nn.LeakyReLU(0.2, True),
|
189 |
+
]
|
190 |
+
|
191 |
+
sequence += [
|
192 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
|
193 |
+
] # output 1 channel prediction map
|
194 |
+
self.main = nn.Sequential(*sequence)
|
195 |
+
|
196 |
+
def forward(self, input):
|
197 |
+
"""Standard forward."""
|
198 |
+
return self.main(input)
|
199 |
+
|
200 |
+
|
201 |
+
class AutoencoderLossWithDiscriminator(nn.Module):
|
202 |
+
def __init__(self, cfg):
|
203 |
+
super().__init__()
|
204 |
+
self.cfg = cfg
|
205 |
+
self.kl_weight = cfg.kl_weight
|
206 |
+
self.logvar = nn.Parameter(torch.ones(size=()) * cfg.logvar_init)
|
207 |
+
|
208 |
+
self.discriminator = NLayerDiscriminator(
|
209 |
+
input_nc=cfg.disc_in_channels,
|
210 |
+
n_layers=cfg.disc_num_layers,
|
211 |
+
use_actnorm=cfg.use_actnorm,
|
212 |
+
).apply(weights_init)
|
213 |
+
|
214 |
+
self.discriminator_iter_start = cfg.disc_start
|
215 |
+
self.discriminator_weight = cfg.disc_weight
|
216 |
+
self.disc_factor = cfg.disc_factor
|
217 |
+
self.disc_loss = hinge_d_loss
|
218 |
+
|
219 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
|
220 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
221 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
222 |
+
|
223 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
224 |
+
d_weight = torch.clamp(
|
225 |
+
d_weight, self.cfg.min_adapt_d_weight, self.cfg.max_adapt_d_weight
|
226 |
+
).detach()
|
227 |
+
d_weight = d_weight * self.discriminator_weight
|
228 |
+
return d_weight
|
229 |
+
|
230 |
+
def forward(
|
231 |
+
self,
|
232 |
+
inputs,
|
233 |
+
reconstructions,
|
234 |
+
posteriors,
|
235 |
+
optimizer_idx,
|
236 |
+
global_step,
|
237 |
+
last_layer,
|
238 |
+
split="train",
|
239 |
+
weights=None,
|
240 |
+
):
|
241 |
+
rec_loss = torch.abs(
|
242 |
+
inputs.contiguous() - reconstructions.contiguous()
|
243 |
+
) # l1 loss
|
244 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
245 |
+
weighted_nll_loss = nll_loss
|
246 |
+
if weights is not None:
|
247 |
+
weighted_nll_loss = weights * nll_loss
|
248 |
+
# weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
249 |
+
weighted_nll_loss = torch.mean(weighted_nll_loss)
|
250 |
+
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
251 |
+
nll_loss = torch.mean(nll_loss)
|
252 |
+
kl_loss = posteriors.kl()
|
253 |
+
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
254 |
+
# ? kl_loss = torch.mean(kl_loss)
|
255 |
+
|
256 |
+
# now the GAN part
|
257 |
+
if optimizer_idx == 0:
|
258 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
259 |
+
g_loss = -torch.mean(logits_fake)
|
260 |
+
|
261 |
+
if self.disc_factor > 0.0:
|
262 |
+
try:
|
263 |
+
d_weight = self.calculate_adaptive_weight(
|
264 |
+
nll_loss, g_loss, last_layer=last_layer
|
265 |
+
)
|
266 |
+
except RuntimeError:
|
267 |
+
assert not self.training
|
268 |
+
d_weight = torch.tensor(0.0)
|
269 |
+
else:
|
270 |
+
d_weight = torch.tensor(0.0)
|
271 |
+
|
272 |
+
disc_factor = adopt_weight(
|
273 |
+
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
274 |
+
)
|
275 |
+
|
276 |
+
total_loss = (
|
277 |
+
weighted_nll_loss
|
278 |
+
+ self.kl_weight * kl_loss
|
279 |
+
+ d_weight * disc_factor * g_loss
|
280 |
+
)
|
281 |
+
|
282 |
+
return {
|
283 |
+
"loss": total_loss,
|
284 |
+
"kl_loss": kl_loss,
|
285 |
+
"rec_loss": rec_loss.mean(),
|
286 |
+
"nll_loss": nll_loss,
|
287 |
+
"g_loss": g_loss,
|
288 |
+
"d_weight": d_weight,
|
289 |
+
"disc_factor": torch.tensor(disc_factor),
|
290 |
+
}
|
291 |
+
|
292 |
+
if optimizer_idx == 1:
|
293 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
294 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
295 |
+
|
296 |
+
disc_factor = adopt_weight(
|
297 |
+
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
298 |
+
)
|
299 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
300 |
+
|
301 |
+
return {
|
302 |
+
"d_loss": d_loss,
|
303 |
+
"logits_real": logits_real.mean(),
|
304 |
+
"logits_fake": logits_fake.mean(),
|
305 |
+
}
|
models/tta/autoencoder/autoencoder_trainer.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from models.base.base_trainer import BaseTrainer
|
8 |
+
from models.tta.autoencoder.autoencoder_dataset import (
|
9 |
+
AutoencoderKLDataset,
|
10 |
+
AutoencoderKLCollator,
|
11 |
+
)
|
12 |
+
from models.tta.autoencoder.autoencoder import AutoencoderKL
|
13 |
+
from models.tta.autoencoder.autoencoder_loss import AutoencoderLossWithDiscriminator
|
14 |
+
from torch.optim import Adam, AdamW
|
15 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
16 |
+
from torch.nn import MSELoss, L1Loss
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
19 |
+
|
20 |
+
|
21 |
+
class AutoencoderKLTrainer(BaseTrainer):
|
22 |
+
def __init__(self, args, cfg):
|
23 |
+
BaseTrainer.__init__(self, args, cfg)
|
24 |
+
self.cfg = cfg
|
25 |
+
self.save_config_file()
|
26 |
+
|
27 |
+
def build_dataset(self):
|
28 |
+
return AutoencoderKLDataset, AutoencoderKLCollator
|
29 |
+
|
30 |
+
def build_optimizer(self):
|
31 |
+
opt_ae = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
|
32 |
+
opt_disc = torch.optim.AdamW(
|
33 |
+
self.criterion.discriminator.parameters(), **self.cfg.train.adam
|
34 |
+
)
|
35 |
+
optimizer = {"opt_ae": opt_ae, "opt_disc": opt_disc}
|
36 |
+
return optimizer
|
37 |
+
|
38 |
+
def build_data_loader(self):
|
39 |
+
Dataset, Collator = self.build_dataset()
|
40 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
41 |
+
datasets_list = []
|
42 |
+
for dataset in self.cfg.dataset:
|
43 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
44 |
+
datasets_list.append(subdataset)
|
45 |
+
train_dataset = ConcatDataset(datasets_list)
|
46 |
+
|
47 |
+
train_collate = Collator(self.cfg)
|
48 |
+
|
49 |
+
# use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
|
50 |
+
train_loader = DataLoader(
|
51 |
+
train_dataset,
|
52 |
+
collate_fn=train_collate,
|
53 |
+
num_workers=self.args.num_workers,
|
54 |
+
batch_size=self.cfg.train.batch_size,
|
55 |
+
pin_memory=False,
|
56 |
+
)
|
57 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
58 |
+
datasets_list = []
|
59 |
+
for dataset in self.cfg.dataset:
|
60 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
61 |
+
datasets_list.append(subdataset)
|
62 |
+
valid_dataset = ConcatDataset(datasets_list)
|
63 |
+
valid_collate = Collator(self.cfg)
|
64 |
+
|
65 |
+
valid_loader = DataLoader(
|
66 |
+
valid_dataset,
|
67 |
+
collate_fn=valid_collate,
|
68 |
+
num_workers=1,
|
69 |
+
batch_size=self.cfg.train.batch_size,
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
raise NotImplementedError("DDP is not supported yet.")
|
73 |
+
# valid_loader = None
|
74 |
+
data_loader = {"train": train_loader, "valid": valid_loader}
|
75 |
+
return data_loader
|
76 |
+
|
77 |
+
# TODO: check it...
|
78 |
+
def build_scheduler(self):
|
79 |
+
return None
|
80 |
+
# return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
|
81 |
+
|
82 |
+
def write_summary(self, losses, stats):
|
83 |
+
for key, value in losses.items():
|
84 |
+
self.sw.add_scalar(key, value, self.step)
|
85 |
+
|
86 |
+
def write_valid_summary(self, losses, stats):
|
87 |
+
for key, value in losses.items():
|
88 |
+
self.sw.add_scalar(key, value, self.step)
|
89 |
+
|
90 |
+
def build_criterion(self):
|
91 |
+
return AutoencoderLossWithDiscriminator(self.cfg.model.loss)
|
92 |
+
|
93 |
+
def get_state_dict(self):
|
94 |
+
if self.scheduler != None:
|
95 |
+
state_dict = {
|
96 |
+
"model": self.model.state_dict(),
|
97 |
+
"optimizer_ae": self.optimizer["opt_ae"].state_dict(),
|
98 |
+
"optimizer_disc": self.optimizer["opt_disc"].state_dict(),
|
99 |
+
"scheduler": self.scheduler.state_dict(),
|
100 |
+
"step": self.step,
|
101 |
+
"epoch": self.epoch,
|
102 |
+
"batch_size": self.cfg.train.batch_size,
|
103 |
+
}
|
104 |
+
else:
|
105 |
+
state_dict = {
|
106 |
+
"model": self.model.state_dict(),
|
107 |
+
"optimizer_ae": self.optimizer["opt_ae"].state_dict(),
|
108 |
+
"optimizer_disc": self.optimizer["opt_disc"].state_dict(),
|
109 |
+
"step": self.step,
|
110 |
+
"epoch": self.epoch,
|
111 |
+
"batch_size": self.cfg.train.batch_size,
|
112 |
+
}
|
113 |
+
return state_dict
|
114 |
+
|
115 |
+
def load_model(self, checkpoint):
|
116 |
+
self.step = checkpoint["step"]
|
117 |
+
self.epoch = checkpoint["epoch"]
|
118 |
+
|
119 |
+
self.model.load_state_dict(checkpoint["model"])
|
120 |
+
self.optimizer["opt_ae"].load_state_dict(checkpoint["optimizer_ae"])
|
121 |
+
self.optimizer["opt_disc"].load_state_dict(checkpoint["optimizer_disc"])
|
122 |
+
if self.scheduler != None:
|
123 |
+
self.scheduler.load_state_dict(checkpoint["scheduler"])
|
124 |
+
|
125 |
+
def build_model(self):
|
126 |
+
self.model = AutoencoderKL(self.cfg.model.autoencoderkl)
|
127 |
+
return self.model
|
128 |
+
|
129 |
+
# TODO: train step
|
130 |
+
def train_step(self, data):
|
131 |
+
global_step = self.step
|
132 |
+
optimizer_idx = global_step % 2
|
133 |
+
|
134 |
+
train_losses = {}
|
135 |
+
total_loss = 0
|
136 |
+
train_states = {}
|
137 |
+
|
138 |
+
inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
|
139 |
+
reconstructions, posterior = self.model(inputs)
|
140 |
+
# train_stats.update(stat)
|
141 |
+
|
142 |
+
train_losses = self.criterion(
|
143 |
+
inputs=inputs,
|
144 |
+
reconstructions=reconstructions,
|
145 |
+
posteriors=posterior,
|
146 |
+
optimizer_idx=optimizer_idx,
|
147 |
+
global_step=global_step,
|
148 |
+
last_layer=self.model.get_last_layer(),
|
149 |
+
split="train",
|
150 |
+
)
|
151 |
+
|
152 |
+
if optimizer_idx == 0:
|
153 |
+
total_loss = train_losses["loss"]
|
154 |
+
self.optimizer["opt_ae"].zero_grad()
|
155 |
+
total_loss.backward()
|
156 |
+
self.optimizer["opt_ae"].step()
|
157 |
+
|
158 |
+
else:
|
159 |
+
total_loss = train_losses["d_loss"]
|
160 |
+
self.optimizer["opt_disc"].zero_grad()
|
161 |
+
total_loss.backward()
|
162 |
+
self.optimizer["opt_disc"].step()
|
163 |
+
|
164 |
+
for item in train_losses:
|
165 |
+
train_losses[item] = train_losses[item].item()
|
166 |
+
|
167 |
+
return train_losses, train_states, total_loss.item()
|
168 |
+
|
169 |
+
# TODO: eval step
|
170 |
+
@torch.no_grad()
|
171 |
+
def eval_step(self, data, index):
|
172 |
+
valid_loss = {}
|
173 |
+
total_valid_loss = 0
|
174 |
+
valid_stats = {}
|
175 |
+
|
176 |
+
inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
|
177 |
+
reconstructions, posterior = self.model(inputs)
|
178 |
+
|
179 |
+
loss = F.l1_loss(inputs, reconstructions)
|
180 |
+
valid_loss["loss"] = loss
|
181 |
+
|
182 |
+
total_valid_loss += loss
|
183 |
+
|
184 |
+
for item in valid_loss:
|
185 |
+
valid_loss[item] = valid_loss[item].item()
|
186 |
+
|
187 |
+
return valid_loss, valid_stats, total_valid_loss.item()
|
models/tta/ldm/__init__.py
ADDED
File without changes
|
models/tta/ldm/attention.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from inspect import isfunction
|
7 |
+
import math
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn, einsum
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
|
14 |
+
class CheckpointFunction(torch.autograd.Function):
|
15 |
+
@staticmethod
|
16 |
+
def forward(ctx, run_function, length, *args):
|
17 |
+
ctx.run_function = run_function
|
18 |
+
ctx.input_tensors = list(args[:length])
|
19 |
+
ctx.input_params = list(args[length:])
|
20 |
+
|
21 |
+
with torch.no_grad():
|
22 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
23 |
+
return output_tensors
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def backward(ctx, *output_grads):
|
27 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
28 |
+
with torch.enable_grad():
|
29 |
+
# Fixes a bug where the first op in run_function modifies the
|
30 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
31 |
+
# Tensors.
|
32 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
33 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
34 |
+
input_grads = torch.autograd.grad(
|
35 |
+
output_tensors,
|
36 |
+
ctx.input_tensors + ctx.input_params,
|
37 |
+
output_grads,
|
38 |
+
allow_unused=True,
|
39 |
+
)
|
40 |
+
del ctx.input_tensors
|
41 |
+
del ctx.input_params
|
42 |
+
del output_tensors
|
43 |
+
return (None, None) + input_grads
|
44 |
+
|
45 |
+
|
46 |
+
def checkpoint(func, inputs, params, flag):
|
47 |
+
"""
|
48 |
+
Evaluate a function without caching intermediate activations, allowing for
|
49 |
+
reduced memory at the expense of extra compute in the backward pass.
|
50 |
+
:param func: the function to evaluate.
|
51 |
+
:param inputs: the argument sequence to pass to `func`.
|
52 |
+
:param params: a sequence of parameters `func` depends on but does not
|
53 |
+
explicitly take as arguments.
|
54 |
+
:param flag: if False, disable gradient checkpointing.
|
55 |
+
"""
|
56 |
+
if flag:
|
57 |
+
args = tuple(inputs) + tuple(params)
|
58 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
59 |
+
else:
|
60 |
+
return func(*inputs)
|
61 |
+
|
62 |
+
|
63 |
+
def exists(val):
|
64 |
+
return val is not None
|
65 |
+
|
66 |
+
|
67 |
+
def uniq(arr):
|
68 |
+
return {el: True for el in arr}.keys()
|
69 |
+
|
70 |
+
|
71 |
+
def default(val, d):
|
72 |
+
if exists(val):
|
73 |
+
return val
|
74 |
+
return d() if isfunction(d) else d
|
75 |
+
|
76 |
+
|
77 |
+
def max_neg_value(t):
|
78 |
+
return -torch.finfo(t.dtype).max
|
79 |
+
|
80 |
+
|
81 |
+
def init_(tensor):
|
82 |
+
dim = tensor.shape[-1]
|
83 |
+
std = 1 / math.sqrt(dim)
|
84 |
+
tensor.uniform_(-std, std)
|
85 |
+
return tensor
|
86 |
+
|
87 |
+
|
88 |
+
# feedforward
|
89 |
+
class GEGLU(nn.Module):
|
90 |
+
def __init__(self, dim_in, dim_out):
|
91 |
+
super().__init__()
|
92 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
96 |
+
return x * F.gelu(gate)
|
97 |
+
|
98 |
+
|
99 |
+
class FeedForward(nn.Module):
|
100 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
101 |
+
super().__init__()
|
102 |
+
inner_dim = int(dim * mult)
|
103 |
+
dim_out = default(dim_out, dim)
|
104 |
+
project_in = (
|
105 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
106 |
+
if not glu
|
107 |
+
else GEGLU(dim, inner_dim)
|
108 |
+
)
|
109 |
+
|
110 |
+
self.net = nn.Sequential(
|
111 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
112 |
+
)
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
return self.net(x)
|
116 |
+
|
117 |
+
|
118 |
+
def zero_module(module):
|
119 |
+
"""
|
120 |
+
Zero out the parameters of a module and return it.
|
121 |
+
"""
|
122 |
+
for p in module.parameters():
|
123 |
+
p.detach().zero_()
|
124 |
+
return module
|
125 |
+
|
126 |
+
|
127 |
+
def Normalize(in_channels):
|
128 |
+
return torch.nn.GroupNorm(
|
129 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
class LinearAttention(nn.Module):
|
134 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
135 |
+
super().__init__()
|
136 |
+
self.heads = heads
|
137 |
+
hidden_dim = dim_head * heads
|
138 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
139 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
b, c, h, w = x.shape
|
143 |
+
qkv = self.to_qkv(x)
|
144 |
+
q, k, v = rearrange(
|
145 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
146 |
+
)
|
147 |
+
k = k.softmax(dim=-1)
|
148 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
149 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
150 |
+
out = rearrange(
|
151 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
152 |
+
)
|
153 |
+
return self.to_out(out)
|
154 |
+
|
155 |
+
|
156 |
+
class SpatialSelfAttention(nn.Module):
|
157 |
+
def __init__(self, in_channels):
|
158 |
+
super().__init__()
|
159 |
+
self.in_channels = in_channels
|
160 |
+
|
161 |
+
self.norm = Normalize(in_channels)
|
162 |
+
self.q = torch.nn.Conv2d(
|
163 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
164 |
+
)
|
165 |
+
self.k = torch.nn.Conv2d(
|
166 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
167 |
+
)
|
168 |
+
self.v = torch.nn.Conv2d(
|
169 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
170 |
+
)
|
171 |
+
self.proj_out = torch.nn.Conv2d(
|
172 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
173 |
+
)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
h_ = x
|
177 |
+
h_ = self.norm(h_)
|
178 |
+
q = self.q(h_)
|
179 |
+
k = self.k(h_)
|
180 |
+
v = self.v(h_)
|
181 |
+
|
182 |
+
# compute attention
|
183 |
+
b, c, h, w = q.shape
|
184 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
185 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
186 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
187 |
+
|
188 |
+
w_ = w_ * (int(c) ** (-0.5))
|
189 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
190 |
+
|
191 |
+
# attend to values
|
192 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
193 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
194 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
195 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
196 |
+
h_ = self.proj_out(h_)
|
197 |
+
|
198 |
+
return x + h_
|
199 |
+
|
200 |
+
|
201 |
+
class CrossAttention(nn.Module):
|
202 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
203 |
+
super().__init__()
|
204 |
+
inner_dim = dim_head * heads
|
205 |
+
context_dim = default(context_dim, query_dim)
|
206 |
+
|
207 |
+
self.scale = dim_head**-0.5
|
208 |
+
self.heads = heads
|
209 |
+
|
210 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
211 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
212 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
213 |
+
|
214 |
+
self.to_out = nn.Sequential(
|
215 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
216 |
+
)
|
217 |
+
|
218 |
+
def forward(self, x, context=None, mask=None):
|
219 |
+
h = self.heads
|
220 |
+
|
221 |
+
q = self.to_q(x)
|
222 |
+
context = default(context, x)
|
223 |
+
k = self.to_k(context)
|
224 |
+
v = self.to_v(context)
|
225 |
+
|
226 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
227 |
+
|
228 |
+
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
229 |
+
|
230 |
+
if exists(mask):
|
231 |
+
mask = rearrange(mask, "b ... -> b (...)")
|
232 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
233 |
+
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
234 |
+
sim.masked_fill_(~mask, max_neg_value)
|
235 |
+
|
236 |
+
# attention, what we cannot get enough of
|
237 |
+
attn = sim.softmax(dim=-1)
|
238 |
+
|
239 |
+
out = einsum("b i j, b j d -> b i d", attn, v)
|
240 |
+
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
241 |
+
return self.to_out(out)
|
242 |
+
|
243 |
+
|
244 |
+
class BasicTransformerBlock(nn.Module):
|
245 |
+
def __init__(
|
246 |
+
self,
|
247 |
+
dim,
|
248 |
+
n_heads,
|
249 |
+
d_head,
|
250 |
+
dropout=0.0,
|
251 |
+
context_dim=None,
|
252 |
+
gated_ff=True,
|
253 |
+
checkpoint=True,
|
254 |
+
):
|
255 |
+
super().__init__()
|
256 |
+
self.attn1 = CrossAttention(
|
257 |
+
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
258 |
+
) # is a self-attention
|
259 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
260 |
+
self.attn2 = CrossAttention(
|
261 |
+
query_dim=dim,
|
262 |
+
context_dim=context_dim,
|
263 |
+
heads=n_heads,
|
264 |
+
dim_head=d_head,
|
265 |
+
dropout=dropout,
|
266 |
+
) # is self-attn if context is none
|
267 |
+
self.norm1 = nn.LayerNorm(dim)
|
268 |
+
self.norm2 = nn.LayerNorm(dim)
|
269 |
+
self.norm3 = nn.LayerNorm(dim)
|
270 |
+
self.checkpoint = checkpoint
|
271 |
+
|
272 |
+
def forward(self, x, context=None):
|
273 |
+
return checkpoint(
|
274 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
275 |
+
)
|
276 |
+
|
277 |
+
def _forward(self, x, context=None):
|
278 |
+
x = self.attn1(self.norm1(x)) + x
|
279 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
280 |
+
x = self.ff(self.norm3(x)) + x
|
281 |
+
return x
|
282 |
+
|
283 |
+
|
284 |
+
class SpatialTransformer(nn.Module):
|
285 |
+
"""
|
286 |
+
Transformer block for image-like data.
|
287 |
+
First, project the input (aka embedding)
|
288 |
+
and reshape to b, t, d.
|
289 |
+
Then apply standard transformer action.
|
290 |
+
Finally, reshape to image
|
291 |
+
"""
|
292 |
+
|
293 |
+
def __init__(
|
294 |
+
self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
|
295 |
+
):
|
296 |
+
super().__init__()
|
297 |
+
self.in_channels = in_channels
|
298 |
+
inner_dim = n_heads * d_head
|
299 |
+
self.norm = Normalize(in_channels)
|
300 |
+
|
301 |
+
self.proj_in = nn.Conv2d(
|
302 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
303 |
+
)
|
304 |
+
|
305 |
+
self.transformer_blocks = nn.ModuleList(
|
306 |
+
[
|
307 |
+
BasicTransformerBlock(
|
308 |
+
inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
|
309 |
+
)
|
310 |
+
for d in range(depth)
|
311 |
+
]
|
312 |
+
)
|
313 |
+
|
314 |
+
self.proj_out = zero_module(
|
315 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
316 |
+
)
|
317 |
+
|
318 |
+
def forward(self, x, context=None):
|
319 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
320 |
+
b, c, h, w = x.shape
|
321 |
+
x_in = x
|
322 |
+
x = self.norm(x)
|
323 |
+
x = self.proj_in(x)
|
324 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
325 |
+
for block in self.transformer_blocks:
|
326 |
+
x = block(x, context=context)
|
327 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
328 |
+
x = self.proj_out(x)
|
329 |
+
return x + x_in
|
models/tta/ldm/audioldm.py
ADDED
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from abc import abstractmethod
|
7 |
+
from functools import partial
|
8 |
+
import math
|
9 |
+
from typing import Iterable
|
10 |
+
|
11 |
+
import os
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
from einops import repeat
|
17 |
+
|
18 |
+
from models.tta.ldm.attention import SpatialTransformer
|
19 |
+
|
20 |
+
# from attention import SpatialTransformer
|
21 |
+
|
22 |
+
|
23 |
+
class CheckpointFunction(torch.autograd.Function):
|
24 |
+
@staticmethod
|
25 |
+
def forward(ctx, run_function, length, *args):
|
26 |
+
ctx.run_function = run_function
|
27 |
+
ctx.input_tensors = list(args[:length])
|
28 |
+
ctx.input_params = list(args[length:])
|
29 |
+
|
30 |
+
with torch.no_grad():
|
31 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
32 |
+
return output_tensors
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def backward(ctx, *output_grads):
|
36 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
37 |
+
with torch.enable_grad():
|
38 |
+
# Fixes a bug where the first op in run_function modifies the
|
39 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
40 |
+
# Tensors.
|
41 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
42 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
43 |
+
input_grads = torch.autograd.grad(
|
44 |
+
output_tensors,
|
45 |
+
ctx.input_tensors + ctx.input_params,
|
46 |
+
output_grads,
|
47 |
+
allow_unused=True,
|
48 |
+
)
|
49 |
+
del ctx.input_tensors
|
50 |
+
del ctx.input_params
|
51 |
+
del output_tensors
|
52 |
+
return (None, None) + input_grads
|
53 |
+
|
54 |
+
|
55 |
+
def checkpoint(func, inputs, params, flag):
|
56 |
+
"""
|
57 |
+
Evaluate a function without caching intermediate activations, allowing for
|
58 |
+
reduced memory at the expense of extra compute in the backward pass.
|
59 |
+
:param func: the function to evaluate.
|
60 |
+
:param inputs: the argument sequence to pass to `func`.
|
61 |
+
:param params: a sequence of parameters `func` depends on but does not
|
62 |
+
explicitly take as arguments.
|
63 |
+
:param flag: if False, disable gradient checkpointing.
|
64 |
+
"""
|
65 |
+
if flag:
|
66 |
+
args = tuple(inputs) + tuple(params)
|
67 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
68 |
+
else:
|
69 |
+
return func(*inputs)
|
70 |
+
|
71 |
+
|
72 |
+
def zero_module(module):
|
73 |
+
"""
|
74 |
+
Zero out the parameters of a module and return it.
|
75 |
+
"""
|
76 |
+
for p in module.parameters():
|
77 |
+
p.detach().zero_()
|
78 |
+
return module
|
79 |
+
|
80 |
+
|
81 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
82 |
+
"""
|
83 |
+
Create sinusoidal timestep embeddings.
|
84 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
85 |
+
These may be fractional.
|
86 |
+
:param dim: the dimension of the output.
|
87 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
88 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
89 |
+
"""
|
90 |
+
if not repeat_only:
|
91 |
+
half = dim // 2
|
92 |
+
freqs = torch.exp(
|
93 |
+
-math.log(max_period)
|
94 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
95 |
+
/ half
|
96 |
+
).to(device=timesteps.device)
|
97 |
+
args = timesteps[:, None].float() * freqs[None]
|
98 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
99 |
+
if dim % 2:
|
100 |
+
embedding = torch.cat(
|
101 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
102 |
+
)
|
103 |
+
else:
|
104 |
+
embedding = repeat(timesteps, "b -> b d", d=dim)
|
105 |
+
return embedding
|
106 |
+
|
107 |
+
|
108 |
+
class GroupNorm32(nn.GroupNorm):
|
109 |
+
def forward(self, x):
|
110 |
+
return super().forward(x.float()).type(x.dtype)
|
111 |
+
|
112 |
+
|
113 |
+
def normalization(channels):
|
114 |
+
"""
|
115 |
+
Make a standard normalization layer.
|
116 |
+
:param channels: number of input channels.
|
117 |
+
:return: an nn.Module for normalization.
|
118 |
+
"""
|
119 |
+
return GroupNorm32(32, channels)
|
120 |
+
|
121 |
+
|
122 |
+
def count_flops_attn(model, _x, y):
|
123 |
+
"""
|
124 |
+
A counter for the `thop` package to count the operations in an
|
125 |
+
attention operation.
|
126 |
+
Meant to be used like:
|
127 |
+
macs, params = thop.profile(
|
128 |
+
model,
|
129 |
+
inputs=(inputs, timestamps),
|
130 |
+
custom_ops={QKVAttention: QKVAttention.count_flops},
|
131 |
+
)
|
132 |
+
"""
|
133 |
+
b, c, *spatial = y[0].shape
|
134 |
+
num_spatial = int(np.prod(spatial))
|
135 |
+
# We perform two matmuls with the same number of ops.
|
136 |
+
# The first computes the weight matrix, the second computes
|
137 |
+
# the combination of the value vectors.
|
138 |
+
matmul_ops = 2 * b * (num_spatial**2) * c
|
139 |
+
model.total_ops += torch.DoubleTensor([matmul_ops])
|
140 |
+
|
141 |
+
|
142 |
+
def conv_nd(dims, *args, **kwargs):
|
143 |
+
"""
|
144 |
+
Create a 1D, 2D, or 3D convolution module.
|
145 |
+
"""
|
146 |
+
if dims == 1:
|
147 |
+
return nn.Conv1d(*args, **kwargs)
|
148 |
+
elif dims == 2:
|
149 |
+
return nn.Conv2d(*args, **kwargs)
|
150 |
+
elif dims == 3:
|
151 |
+
return nn.Conv3d(*args, **kwargs)
|
152 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
153 |
+
|
154 |
+
|
155 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
156 |
+
"""
|
157 |
+
Create a 1D, 2D, or 3D average pooling module.
|
158 |
+
"""
|
159 |
+
if dims == 1:
|
160 |
+
return nn.AvgPool1d(*args, **kwargs)
|
161 |
+
elif dims == 2:
|
162 |
+
return nn.AvgPool2d(*args, **kwargs)
|
163 |
+
elif dims == 3:
|
164 |
+
return nn.AvgPool3d(*args, **kwargs)
|
165 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
166 |
+
|
167 |
+
|
168 |
+
class QKVAttention(nn.Module):
|
169 |
+
"""
|
170 |
+
A module which performs QKV attention and splits in a different order.
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, n_heads):
|
174 |
+
super().__init__()
|
175 |
+
self.n_heads = n_heads
|
176 |
+
|
177 |
+
def forward(self, qkv):
|
178 |
+
"""
|
179 |
+
Apply QKV attention.
|
180 |
+
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
181 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
182 |
+
"""
|
183 |
+
|
184 |
+
bs, width, length = qkv.shape
|
185 |
+
assert width % (3 * self.n_heads) == 0
|
186 |
+
ch = width // (3 * self.n_heads)
|
187 |
+
q, k, v = qkv.chunk(3, dim=1) # [N x (H * C) x T]
|
188 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
189 |
+
weight = torch.einsum(
|
190 |
+
"bct,bcs->bts",
|
191 |
+
(q * scale).view(bs * self.n_heads, ch, length),
|
192 |
+
(k * scale).view(bs * self.n_heads, ch, length),
|
193 |
+
) # More stable with f16 than dividing afterwards
|
194 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
195 |
+
a = torch.einsum(
|
196 |
+
"bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
|
197 |
+
)
|
198 |
+
return a.reshape(bs, -1, length)
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def count_flops(model, _x, y):
|
202 |
+
return count_flops_attn(model, _x, y)
|
203 |
+
|
204 |
+
|
205 |
+
class QKVAttentionLegacy(nn.Module):
|
206 |
+
"""
|
207 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
208 |
+
"""
|
209 |
+
|
210 |
+
def __init__(self, n_heads):
|
211 |
+
super().__init__()
|
212 |
+
self.n_heads = n_heads
|
213 |
+
|
214 |
+
def forward(self, qkv):
|
215 |
+
"""
|
216 |
+
Apply QKV attention.
|
217 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
218 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
219 |
+
"""
|
220 |
+
bs, width, length = qkv.shape
|
221 |
+
assert width % (3 * self.n_heads) == 0
|
222 |
+
ch = width // (3 * self.n_heads)
|
223 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
224 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
225 |
+
weight = torch.einsum(
|
226 |
+
"bct,bcs->bts", q * scale, k * scale
|
227 |
+
) # More stable with f16 than dividing afterwards
|
228 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
229 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
230 |
+
return a.reshape(bs, -1, length)
|
231 |
+
|
232 |
+
@staticmethod
|
233 |
+
def count_flops(model, _x, y):
|
234 |
+
return count_flops_attn(model, _x, y)
|
235 |
+
|
236 |
+
|
237 |
+
class AttentionPool2d(nn.Module):
|
238 |
+
"""
|
239 |
+
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
240 |
+
"""
|
241 |
+
|
242 |
+
def __init__(
|
243 |
+
self,
|
244 |
+
spacial_dim: int,
|
245 |
+
embed_dim: int,
|
246 |
+
num_heads_channels: int,
|
247 |
+
output_dim: int = None,
|
248 |
+
):
|
249 |
+
super().__init__()
|
250 |
+
self.positional_embedding = nn.Parameter(
|
251 |
+
torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
|
252 |
+
)
|
253 |
+
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
254 |
+
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
255 |
+
self.num_heads = embed_dim // num_heads_channels
|
256 |
+
self.attention = QKVAttention(self.num_heads)
|
257 |
+
|
258 |
+
def forward(self, x):
|
259 |
+
b, c, *_spatial = x.shape
|
260 |
+
x = x.reshape(b, c, -1) # NC(HW)
|
261 |
+
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
262 |
+
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
263 |
+
x = self.qkv_proj(x)
|
264 |
+
x = self.attention(x)
|
265 |
+
x = self.c_proj(x)
|
266 |
+
return x[:, :, 0]
|
267 |
+
|
268 |
+
|
269 |
+
class TimestepBlock(nn.Module):
|
270 |
+
"""
|
271 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
272 |
+
"""
|
273 |
+
|
274 |
+
@abstractmethod
|
275 |
+
def forward(self, x, emb):
|
276 |
+
"""
|
277 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
278 |
+
"""
|
279 |
+
|
280 |
+
|
281 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
282 |
+
"""
|
283 |
+
A sequential module that passes timestep embeddings to the children that
|
284 |
+
support it as an extra input.
|
285 |
+
"""
|
286 |
+
|
287 |
+
def forward(self, x, emb, context=None):
|
288 |
+
for layer in self:
|
289 |
+
if isinstance(layer, TimestepBlock):
|
290 |
+
x = layer(x, emb)
|
291 |
+
elif isinstance(layer, SpatialTransformer):
|
292 |
+
x = layer(x, context)
|
293 |
+
else:
|
294 |
+
x = layer(x)
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
class Upsample(nn.Module):
|
299 |
+
"""
|
300 |
+
An upsampling layer with an optional convolution.
|
301 |
+
:param channels: channels in the inputs and outputs.
|
302 |
+
:param use_conv: a bool determining if a convolution is applied.
|
303 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
304 |
+
upsampling occurs in the inner-two dimensions.
|
305 |
+
"""
|
306 |
+
|
307 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
308 |
+
super().__init__()
|
309 |
+
self.channels = channels
|
310 |
+
self.out_channels = out_channels or channels
|
311 |
+
self.use_conv = use_conv
|
312 |
+
self.dims = dims
|
313 |
+
if use_conv:
|
314 |
+
self.conv = conv_nd(
|
315 |
+
dims, self.channels, self.out_channels, 3, padding=padding
|
316 |
+
)
|
317 |
+
|
318 |
+
def forward(self, x):
|
319 |
+
assert x.shape[1] == self.channels
|
320 |
+
if self.dims == 3:
|
321 |
+
x = F.interpolate(
|
322 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
326 |
+
if self.use_conv:
|
327 |
+
x = self.conv(x)
|
328 |
+
return x
|
329 |
+
|
330 |
+
|
331 |
+
class TransposedUpsample(nn.Module):
|
332 |
+
"Learned 2x upsampling without padding"
|
333 |
+
|
334 |
+
def __init__(self, channels, out_channels=None, ks=5):
|
335 |
+
super().__init__()
|
336 |
+
self.channels = channels
|
337 |
+
self.out_channels = out_channels or channels
|
338 |
+
|
339 |
+
self.up = nn.ConvTranspose2d(
|
340 |
+
self.channels, self.out_channels, kernel_size=ks, stride=2
|
341 |
+
)
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
return self.up(x)
|
345 |
+
|
346 |
+
|
347 |
+
class Downsample(nn.Module):
|
348 |
+
"""
|
349 |
+
A downsampling layer with an optional convolution.
|
350 |
+
:param channels: channels in the inputs and outputs.
|
351 |
+
:param use_conv: a bool determining if a convolution is applied.
|
352 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
353 |
+
downsampling occurs in the inner-two dimensions.
|
354 |
+
"""
|
355 |
+
|
356 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
357 |
+
super().__init__()
|
358 |
+
self.channels = channels
|
359 |
+
self.out_channels = out_channels or channels
|
360 |
+
self.use_conv = use_conv
|
361 |
+
self.dims = dims
|
362 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
363 |
+
if use_conv:
|
364 |
+
self.op = conv_nd(
|
365 |
+
dims,
|
366 |
+
self.channels,
|
367 |
+
self.out_channels,
|
368 |
+
3,
|
369 |
+
stride=stride,
|
370 |
+
padding=padding,
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
assert self.channels == self.out_channels
|
374 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
375 |
+
|
376 |
+
def forward(self, x):
|
377 |
+
assert x.shape[1] == self.channels
|
378 |
+
return self.op(x)
|
379 |
+
|
380 |
+
|
381 |
+
class ResBlock(TimestepBlock):
|
382 |
+
"""
|
383 |
+
A residual block that can optionally change the number of channels.
|
384 |
+
:param channels: the number of input channels.
|
385 |
+
:param emb_channels: the number of timestep embedding channels.
|
386 |
+
:param dropout: the rate of dropout.
|
387 |
+
:param out_channels: if specified, the number of out channels.
|
388 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
389 |
+
convolution instead of a smaller 1x1 convolution to change the
|
390 |
+
channels in the skip connection.
|
391 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
392 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
393 |
+
:param up: if True, use this block for upsampling.
|
394 |
+
:param down: if True, use this block for downsampling.
|
395 |
+
"""
|
396 |
+
|
397 |
+
def __init__(
|
398 |
+
self,
|
399 |
+
channels,
|
400 |
+
emb_channels,
|
401 |
+
dropout,
|
402 |
+
out_channels=None,
|
403 |
+
use_conv=False,
|
404 |
+
use_scale_shift_norm=False,
|
405 |
+
dims=2,
|
406 |
+
use_checkpoint=False,
|
407 |
+
up=False,
|
408 |
+
down=False,
|
409 |
+
):
|
410 |
+
super().__init__()
|
411 |
+
self.channels = channels
|
412 |
+
self.emb_channels = emb_channels
|
413 |
+
self.dropout = dropout
|
414 |
+
self.out_channels = out_channels or channels
|
415 |
+
self.use_conv = use_conv
|
416 |
+
self.use_checkpoint = use_checkpoint
|
417 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
418 |
+
|
419 |
+
self.in_layers = nn.Sequential(
|
420 |
+
normalization(channels),
|
421 |
+
nn.SiLU(),
|
422 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
423 |
+
)
|
424 |
+
|
425 |
+
self.updown = up or down
|
426 |
+
|
427 |
+
if up:
|
428 |
+
self.h_upd = Upsample(channels, False, dims)
|
429 |
+
self.x_upd = Upsample(channels, False, dims)
|
430 |
+
elif down:
|
431 |
+
self.h_upd = Downsample(channels, False, dims)
|
432 |
+
self.x_upd = Downsample(channels, False, dims)
|
433 |
+
else:
|
434 |
+
self.h_upd = self.x_upd = nn.Identity()
|
435 |
+
|
436 |
+
self.emb_layers = nn.Sequential(
|
437 |
+
nn.SiLU(),
|
438 |
+
nn.Linear(
|
439 |
+
emb_channels,
|
440 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
441 |
+
),
|
442 |
+
)
|
443 |
+
self.out_layers = nn.Sequential(
|
444 |
+
normalization(self.out_channels),
|
445 |
+
nn.SiLU(),
|
446 |
+
nn.Dropout(p=dropout),
|
447 |
+
zero_module(
|
448 |
+
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
449 |
+
),
|
450 |
+
)
|
451 |
+
|
452 |
+
if self.out_channels == channels:
|
453 |
+
self.skip_connection = nn.Identity()
|
454 |
+
elif use_conv:
|
455 |
+
self.skip_connection = conv_nd(
|
456 |
+
dims, channels, self.out_channels, 3, padding=1
|
457 |
+
)
|
458 |
+
else:
|
459 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
460 |
+
|
461 |
+
def forward(self, x, emb):
|
462 |
+
"""
|
463 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
464 |
+
:param x: an [N x C x ...] Tensor of features.
|
465 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
466 |
+
:return: an [N x C x ...] Tensor of outputs.
|
467 |
+
"""
|
468 |
+
return checkpoint(
|
469 |
+
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
470 |
+
)
|
471 |
+
|
472 |
+
def _forward(self, x, emb):
|
473 |
+
if self.updown:
|
474 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
475 |
+
h = in_rest(x)
|
476 |
+
h = self.h_upd(h)
|
477 |
+
x = self.x_upd(x)
|
478 |
+
h = in_conv(h)
|
479 |
+
else:
|
480 |
+
h = self.in_layers(x)
|
481 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
482 |
+
while len(emb_out.shape) < len(h.shape):
|
483 |
+
emb_out = emb_out[..., None]
|
484 |
+
if self.use_scale_shift_norm:
|
485 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
486 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
487 |
+
h = out_norm(h) * (1 + scale) + shift
|
488 |
+
h = out_rest(h)
|
489 |
+
else:
|
490 |
+
h = h + emb_out
|
491 |
+
h = self.out_layers(h)
|
492 |
+
return self.skip_connection(x) + h
|
493 |
+
|
494 |
+
|
495 |
+
class AttentionBlock(nn.Module):
|
496 |
+
"""
|
497 |
+
An attention block that allows spatial positions to attend to each other.
|
498 |
+
Originally ported from here, but adapted to the N-d case.
|
499 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
500 |
+
"""
|
501 |
+
|
502 |
+
def __init__(
|
503 |
+
self,
|
504 |
+
channels,
|
505 |
+
num_heads=1,
|
506 |
+
num_head_channels=-1,
|
507 |
+
use_checkpoint=False,
|
508 |
+
use_new_attention_order=False,
|
509 |
+
):
|
510 |
+
super().__init__()
|
511 |
+
self.channels = channels
|
512 |
+
if num_head_channels == -1:
|
513 |
+
self.num_heads = num_heads
|
514 |
+
else:
|
515 |
+
assert (
|
516 |
+
channels % num_head_channels == 0
|
517 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
518 |
+
self.num_heads = channels // num_head_channels
|
519 |
+
self.use_checkpoint = use_checkpoint
|
520 |
+
self.norm = normalization(channels)
|
521 |
+
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
522 |
+
if use_new_attention_order:
|
523 |
+
# split qkv before split heads
|
524 |
+
self.attention = QKVAttention(self.num_heads)
|
525 |
+
else:
|
526 |
+
# split heads before split qkv
|
527 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
528 |
+
|
529 |
+
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
530 |
+
|
531 |
+
def forward(self, x):
|
532 |
+
return checkpoint(
|
533 |
+
self._forward, (x,), self.parameters(), True
|
534 |
+
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
535 |
+
# return pt_checkpoint(self._forward, x) # pytorch
|
536 |
+
|
537 |
+
def _forward(self, x):
|
538 |
+
b, c, *spatial = x.shape
|
539 |
+
x = x.reshape(b, c, -1)
|
540 |
+
qkv = self.qkv(self.norm(x))
|
541 |
+
h = self.attention(qkv)
|
542 |
+
h = self.proj_out(h)
|
543 |
+
return (x + h).reshape(b, c, *spatial)
|
544 |
+
|
545 |
+
|
546 |
+
class UNetModel(nn.Module):
|
547 |
+
"""
|
548 |
+
The full UNet model with attention and timestep embedding.
|
549 |
+
:param in_channels: channels in the input Tensor.
|
550 |
+
:param model_channels: base channel count for the model.
|
551 |
+
:param out_channels: channels in the output Tensor.
|
552 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
553 |
+
:param attention_resolutions: a collection of downsample rates at which
|
554 |
+
attention will take place. May be a set, list, or tuple.
|
555 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
556 |
+
will be used.
|
557 |
+
:param dropout: the dropout probability.
|
558 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
559 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
560 |
+
downsampling.
|
561 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
562 |
+
:param num_classes: if specified (as an int), then this model will be
|
563 |
+
class-conditional with `num_classes` classes.
|
564 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
565 |
+
:param num_heads: the number of attention heads in each attention layer.
|
566 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
567 |
+
a fixed channel width per attention head.
|
568 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
569 |
+
of heads for upsampling. Deprecated.
|
570 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
571 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
572 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
573 |
+
increased efficiency.
|
574 |
+
"""
|
575 |
+
|
576 |
+
def __init__(
|
577 |
+
self,
|
578 |
+
image_size,
|
579 |
+
in_channels,
|
580 |
+
model_channels,
|
581 |
+
out_channels,
|
582 |
+
num_res_blocks,
|
583 |
+
attention_resolutions,
|
584 |
+
dropout=0,
|
585 |
+
channel_mult=(1, 2, 4, 8),
|
586 |
+
conv_resample=True,
|
587 |
+
dims=2,
|
588 |
+
num_classes=None,
|
589 |
+
use_checkpoint=False,
|
590 |
+
use_fp16=False,
|
591 |
+
num_heads=-1,
|
592 |
+
num_head_channels=-1,
|
593 |
+
num_heads_upsample=-1,
|
594 |
+
use_scale_shift_norm=False,
|
595 |
+
resblock_updown=False,
|
596 |
+
use_new_attention_order=False,
|
597 |
+
use_spatial_transformer=False, # custom transformer support
|
598 |
+
transformer_depth=1, # custom transformer support
|
599 |
+
context_dim=None, # custom transformer support
|
600 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
601 |
+
legacy=True,
|
602 |
+
):
|
603 |
+
super().__init__()
|
604 |
+
if use_spatial_transformer:
|
605 |
+
assert (
|
606 |
+
context_dim is not None
|
607 |
+
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
|
608 |
+
|
609 |
+
if context_dim is not None:
|
610 |
+
assert (
|
611 |
+
use_spatial_transformer
|
612 |
+
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
613 |
+
from omegaconf.listconfig import ListConfig
|
614 |
+
|
615 |
+
if type(context_dim) == ListConfig:
|
616 |
+
context_dim = list(context_dim)
|
617 |
+
|
618 |
+
if num_heads_upsample == -1:
|
619 |
+
num_heads_upsample = num_heads
|
620 |
+
|
621 |
+
if num_heads == -1:
|
622 |
+
assert (
|
623 |
+
num_head_channels != -1
|
624 |
+
), "Either num_heads or num_head_channels has to be set"
|
625 |
+
|
626 |
+
if num_head_channels == -1:
|
627 |
+
assert (
|
628 |
+
num_heads != -1
|
629 |
+
), "Either num_heads or num_head_channels has to be set"
|
630 |
+
|
631 |
+
self.image_size = image_size
|
632 |
+
self.in_channels = in_channels
|
633 |
+
self.model_channels = model_channels
|
634 |
+
self.out_channels = out_channels
|
635 |
+
self.num_res_blocks = num_res_blocks
|
636 |
+
self.attention_resolutions = attention_resolutions
|
637 |
+
self.dropout = dropout
|
638 |
+
self.channel_mult = channel_mult
|
639 |
+
self.conv_resample = conv_resample
|
640 |
+
self.num_classes = num_classes
|
641 |
+
self.use_checkpoint = use_checkpoint
|
642 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
643 |
+
self.num_heads = num_heads
|
644 |
+
self.num_head_channels = num_head_channels
|
645 |
+
self.num_heads_upsample = num_heads_upsample
|
646 |
+
self.predict_codebook_ids = n_embed is not None
|
647 |
+
|
648 |
+
time_embed_dim = model_channels * 4
|
649 |
+
self.time_embed = nn.Sequential(
|
650 |
+
nn.Linear(model_channels, time_embed_dim),
|
651 |
+
nn.SiLU(),
|
652 |
+
nn.Linear(time_embed_dim, time_embed_dim),
|
653 |
+
)
|
654 |
+
|
655 |
+
if self.num_classes is not None:
|
656 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
657 |
+
|
658 |
+
self.input_blocks = nn.ModuleList(
|
659 |
+
[
|
660 |
+
TimestepEmbedSequential(
|
661 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
662 |
+
)
|
663 |
+
]
|
664 |
+
)
|
665 |
+
self._feature_size = model_channels
|
666 |
+
input_block_chans = [model_channels]
|
667 |
+
ch = model_channels
|
668 |
+
ds = 1
|
669 |
+
for level, mult in enumerate(channel_mult):
|
670 |
+
for _ in range(num_res_blocks):
|
671 |
+
layers = [
|
672 |
+
ResBlock(
|
673 |
+
ch,
|
674 |
+
time_embed_dim,
|
675 |
+
dropout,
|
676 |
+
out_channels=mult * model_channels,
|
677 |
+
dims=dims,
|
678 |
+
use_checkpoint=use_checkpoint,
|
679 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
680 |
+
)
|
681 |
+
]
|
682 |
+
ch = mult * model_channels
|
683 |
+
if ds in attention_resolutions:
|
684 |
+
if num_head_channels == -1:
|
685 |
+
dim_head = ch // num_heads
|
686 |
+
else:
|
687 |
+
num_heads = ch // num_head_channels
|
688 |
+
dim_head = num_head_channels
|
689 |
+
if legacy:
|
690 |
+
# num_heads = 1
|
691 |
+
dim_head = (
|
692 |
+
ch // num_heads
|
693 |
+
if use_spatial_transformer
|
694 |
+
else num_head_channels
|
695 |
+
)
|
696 |
+
layers.append(
|
697 |
+
AttentionBlock(
|
698 |
+
ch,
|
699 |
+
use_checkpoint=use_checkpoint,
|
700 |
+
num_heads=num_heads,
|
701 |
+
num_head_channels=dim_head,
|
702 |
+
use_new_attention_order=use_new_attention_order,
|
703 |
+
)
|
704 |
+
if not use_spatial_transformer
|
705 |
+
else SpatialTransformer(
|
706 |
+
ch,
|
707 |
+
num_heads,
|
708 |
+
dim_head,
|
709 |
+
depth=transformer_depth,
|
710 |
+
context_dim=context_dim,
|
711 |
+
)
|
712 |
+
)
|
713 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
714 |
+
self._feature_size += ch
|
715 |
+
input_block_chans.append(ch)
|
716 |
+
if level != len(channel_mult) - 1:
|
717 |
+
out_ch = ch
|
718 |
+
self.input_blocks.append(
|
719 |
+
TimestepEmbedSequential(
|
720 |
+
ResBlock(
|
721 |
+
ch,
|
722 |
+
time_embed_dim,
|
723 |
+
dropout,
|
724 |
+
out_channels=out_ch,
|
725 |
+
dims=dims,
|
726 |
+
use_checkpoint=use_checkpoint,
|
727 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
728 |
+
down=True,
|
729 |
+
)
|
730 |
+
if resblock_updown
|
731 |
+
else Downsample(
|
732 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
733 |
+
)
|
734 |
+
)
|
735 |
+
)
|
736 |
+
ch = out_ch
|
737 |
+
input_block_chans.append(ch)
|
738 |
+
ds *= 2
|
739 |
+
self._feature_size += ch
|
740 |
+
|
741 |
+
if num_head_channels == -1:
|
742 |
+
dim_head = ch // num_heads
|
743 |
+
else:
|
744 |
+
num_heads = ch // num_head_channels
|
745 |
+
dim_head = num_head_channels
|
746 |
+
if legacy:
|
747 |
+
# num_heads = 1
|
748 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
749 |
+
self.middle_block = TimestepEmbedSequential(
|
750 |
+
ResBlock(
|
751 |
+
ch,
|
752 |
+
time_embed_dim,
|
753 |
+
dropout,
|
754 |
+
dims=dims,
|
755 |
+
use_checkpoint=use_checkpoint,
|
756 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
757 |
+
),
|
758 |
+
AttentionBlock(
|
759 |
+
ch,
|
760 |
+
use_checkpoint=use_checkpoint,
|
761 |
+
num_heads=num_heads,
|
762 |
+
num_head_channels=dim_head,
|
763 |
+
use_new_attention_order=use_new_attention_order,
|
764 |
+
)
|
765 |
+
if not use_spatial_transformer
|
766 |
+
else SpatialTransformer(
|
767 |
+
ch,
|
768 |
+
num_heads,
|
769 |
+
dim_head,
|
770 |
+
depth=transformer_depth,
|
771 |
+
context_dim=context_dim,
|
772 |
+
),
|
773 |
+
ResBlock(
|
774 |
+
ch,
|
775 |
+
time_embed_dim,
|
776 |
+
dropout,
|
777 |
+
dims=dims,
|
778 |
+
use_checkpoint=use_checkpoint,
|
779 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
780 |
+
),
|
781 |
+
)
|
782 |
+
self._feature_size += ch
|
783 |
+
|
784 |
+
self.output_blocks = nn.ModuleList([])
|
785 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
786 |
+
for i in range(num_res_blocks + 1):
|
787 |
+
ich = input_block_chans.pop()
|
788 |
+
layers = [
|
789 |
+
ResBlock(
|
790 |
+
ch + ich,
|
791 |
+
time_embed_dim,
|
792 |
+
dropout,
|
793 |
+
out_channels=model_channels * mult,
|
794 |
+
dims=dims,
|
795 |
+
use_checkpoint=use_checkpoint,
|
796 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
797 |
+
)
|
798 |
+
]
|
799 |
+
ch = model_channels * mult
|
800 |
+
if ds in attention_resolutions:
|
801 |
+
if num_head_channels == -1:
|
802 |
+
dim_head = ch // num_heads
|
803 |
+
else:
|
804 |
+
num_heads = ch // num_head_channels
|
805 |
+
dim_head = num_head_channels
|
806 |
+
if legacy:
|
807 |
+
# num_heads = 1
|
808 |
+
dim_head = (
|
809 |
+
ch // num_heads
|
810 |
+
if use_spatial_transformer
|
811 |
+
else num_head_channels
|
812 |
+
)
|
813 |
+
layers.append(
|
814 |
+
AttentionBlock(
|
815 |
+
ch,
|
816 |
+
use_checkpoint=use_checkpoint,
|
817 |
+
num_heads=num_heads_upsample,
|
818 |
+
num_head_channels=dim_head,
|
819 |
+
use_new_attention_order=use_new_attention_order,
|
820 |
+
)
|
821 |
+
if not use_spatial_transformer
|
822 |
+
else SpatialTransformer(
|
823 |
+
ch,
|
824 |
+
num_heads,
|
825 |
+
dim_head,
|
826 |
+
depth=transformer_depth,
|
827 |
+
context_dim=context_dim,
|
828 |
+
)
|
829 |
+
)
|
830 |
+
if level and i == num_res_blocks:
|
831 |
+
out_ch = ch
|
832 |
+
layers.append(
|
833 |
+
ResBlock(
|
834 |
+
ch,
|
835 |
+
time_embed_dim,
|
836 |
+
dropout,
|
837 |
+
out_channels=out_ch,
|
838 |
+
dims=dims,
|
839 |
+
use_checkpoint=use_checkpoint,
|
840 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
841 |
+
up=True,
|
842 |
+
)
|
843 |
+
if resblock_updown
|
844 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
845 |
+
)
|
846 |
+
ds //= 2
|
847 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
848 |
+
self._feature_size += ch
|
849 |
+
|
850 |
+
self.out = nn.Sequential(
|
851 |
+
normalization(ch),
|
852 |
+
nn.SiLU(),
|
853 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
854 |
+
)
|
855 |
+
if self.predict_codebook_ids:
|
856 |
+
self.id_predictor = nn.Sequential(
|
857 |
+
normalization(ch),
|
858 |
+
conv_nd(dims, model_channels, n_embed, 1),
|
859 |
+
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
860 |
+
)
|
861 |
+
|
862 |
+
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
863 |
+
"""
|
864 |
+
Apply the model to an input batch.
|
865 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
866 |
+
:param timesteps: a 1-D batch of timesteps.
|
867 |
+
:param context: conditioning plugged in via crossattn
|
868 |
+
:param y: an [N] Tensor of labels, if class-conditional.
|
869 |
+
:return: an [N x C x ...] Tensor of outputs.
|
870 |
+
"""
|
871 |
+
assert (y is not None) == (
|
872 |
+
self.num_classes is not None
|
873 |
+
), "must specify y if and only if the model is class-conditional"
|
874 |
+
hs = []
|
875 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
876 |
+
emb = self.time_embed(t_emb)
|
877 |
+
|
878 |
+
if self.num_classes is not None:
|
879 |
+
assert y.shape == (x.shape[0],)
|
880 |
+
emb = emb + self.label_emb(y)
|
881 |
+
|
882 |
+
h = x.type(self.dtype)
|
883 |
+
for module in self.input_blocks:
|
884 |
+
h = module(h, emb, context)
|
885 |
+
hs.append(h)
|
886 |
+
h = self.middle_block(h, emb, context)
|
887 |
+
for module in self.output_blocks:
|
888 |
+
# print(h.shape, hs[-1].shape)
|
889 |
+
if h.shape != hs[-1].shape:
|
890 |
+
if h.shape[-1] > hs[-1].shape[-1]:
|
891 |
+
h = h[:, :, :, : hs[-1].shape[-1]]
|
892 |
+
if h.shape[-2] > hs[-1].shape[-2]:
|
893 |
+
h = h[:, :, : hs[-1].shape[-2], :]
|
894 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
895 |
+
h = module(h, emb, context)
|
896 |
+
# print(h.shape)
|
897 |
+
h = h.type(x.dtype)
|
898 |
+
if self.predict_codebook_ids:
|
899 |
+
return self.id_predictor(h)
|
900 |
+
else:
|
901 |
+
return self.out(h)
|
902 |
+
|
903 |
+
|
904 |
+
class AudioLDM(nn.Module):
|
905 |
+
def __init__(self, cfg):
|
906 |
+
super().__init__()
|
907 |
+
self.cfg = cfg
|
908 |
+
self.unet = UNetModel(
|
909 |
+
image_size=cfg.image_size,
|
910 |
+
in_channels=cfg.in_channels,
|
911 |
+
out_channels=cfg.out_channels,
|
912 |
+
model_channels=cfg.model_channels,
|
913 |
+
attention_resolutions=cfg.attention_resolutions,
|
914 |
+
num_res_blocks=cfg.num_res_blocks,
|
915 |
+
channel_mult=cfg.channel_mult,
|
916 |
+
num_heads=cfg.num_heads,
|
917 |
+
use_spatial_transformer=cfg.use_spatial_transformer,
|
918 |
+
transformer_depth=cfg.transformer_depth,
|
919 |
+
context_dim=cfg.context_dim,
|
920 |
+
use_checkpoint=cfg.use_checkpoint,
|
921 |
+
legacy=cfg.legacy,
|
922 |
+
)
|
923 |
+
|
924 |
+
def forward(self, x, timesteps=None, context=None, y=None):
|
925 |
+
x = self.unet(x=x, timesteps=timesteps, context=context, y=y)
|
926 |
+
return x
|
models/tta/ldm/audioldm_dataset.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
from torch.nn.utils.rnn import pad_sequence
|
9 |
+
from utils.data_utils import *
|
10 |
+
|
11 |
+
|
12 |
+
from models.base.base_dataset import (
|
13 |
+
BaseCollator,
|
14 |
+
BaseDataset,
|
15 |
+
BaseTestDataset,
|
16 |
+
BaseTestCollator,
|
17 |
+
)
|
18 |
+
import librosa
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer
|
21 |
+
|
22 |
+
|
23 |
+
class AudioLDMDataset(BaseDataset):
|
24 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
25 |
+
BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
|
26 |
+
|
27 |
+
self.cfg = cfg
|
28 |
+
|
29 |
+
# utt2melspec
|
30 |
+
if cfg.preprocess.use_melspec:
|
31 |
+
self.utt2melspec_path = {}
|
32 |
+
for utt_info in self.metadata:
|
33 |
+
dataset = utt_info["Dataset"]
|
34 |
+
uid = utt_info["Uid"]
|
35 |
+
utt = "{}_{}".format(dataset, uid)
|
36 |
+
|
37 |
+
self.utt2melspec_path[utt] = os.path.join(
|
38 |
+
cfg.preprocess.processed_dir,
|
39 |
+
dataset,
|
40 |
+
cfg.preprocess.melspec_dir,
|
41 |
+
uid + ".npy",
|
42 |
+
)
|
43 |
+
|
44 |
+
# utt2wav
|
45 |
+
if cfg.preprocess.use_wav:
|
46 |
+
self.utt2wav_path = {}
|
47 |
+
for utt_info in self.metadata:
|
48 |
+
dataset = utt_info["Dataset"]
|
49 |
+
uid = utt_info["Uid"]
|
50 |
+
utt = "{}_{}".format(dataset, uid)
|
51 |
+
|
52 |
+
self.utt2wav_path[utt] = os.path.join(
|
53 |
+
cfg.preprocess.processed_dir,
|
54 |
+
dataset,
|
55 |
+
cfg.preprocess.wav_dir,
|
56 |
+
uid + ".wav",
|
57 |
+
)
|
58 |
+
|
59 |
+
# utt2caption
|
60 |
+
if cfg.preprocess.use_caption:
|
61 |
+
self.utt2caption = {}
|
62 |
+
for utt_info in self.metadata:
|
63 |
+
dataset = utt_info["Dataset"]
|
64 |
+
uid = utt_info["Uid"]
|
65 |
+
utt = "{}_{}".format(dataset, uid)
|
66 |
+
|
67 |
+
self.utt2caption[utt] = utt_info["Caption"]
|
68 |
+
|
69 |
+
def __getitem__(self, index):
|
70 |
+
# melspec: (n_mels, T)
|
71 |
+
# wav: (T,)
|
72 |
+
|
73 |
+
single_feature = BaseDataset.__getitem__(self, index)
|
74 |
+
|
75 |
+
utt_info = self.metadata[index]
|
76 |
+
dataset = utt_info["Dataset"]
|
77 |
+
uid = utt_info["Uid"]
|
78 |
+
utt = "{}_{}".format(dataset, uid)
|
79 |
+
|
80 |
+
if self.cfg.preprocess.use_melspec:
|
81 |
+
single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
|
82 |
+
|
83 |
+
if self.cfg.preprocess.use_wav:
|
84 |
+
wav, sr = librosa.load(
|
85 |
+
self.utt2wav_path[utt], sr=16000
|
86 |
+
) # hard coding for 16KHz...
|
87 |
+
single_feature["wav"] = wav
|
88 |
+
|
89 |
+
if self.cfg.preprocess.use_caption:
|
90 |
+
cond_mask = np.random.choice(
|
91 |
+
[1, 0],
|
92 |
+
p=[
|
93 |
+
self.cfg.preprocess.cond_mask_prob,
|
94 |
+
1 - self.cfg.preprocess.cond_mask_prob,
|
95 |
+
],
|
96 |
+
) # (0.1, 0.9)
|
97 |
+
if cond_mask:
|
98 |
+
single_feature["caption"] = ""
|
99 |
+
else:
|
100 |
+
single_feature["caption"] = self.utt2caption[utt]
|
101 |
+
|
102 |
+
return single_feature
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return len(self.metadata)
|
106 |
+
|
107 |
+
|
108 |
+
class AudioLDMCollator(BaseCollator):
|
109 |
+
def __init__(self, cfg):
|
110 |
+
BaseCollator.__init__(self, cfg)
|
111 |
+
|
112 |
+
self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
|
113 |
+
|
114 |
+
def __call__(self, batch):
|
115 |
+
# mel: (B, n_mels, T)
|
116 |
+
# wav (option): (B, T)
|
117 |
+
# text_input_ids: (B, L)
|
118 |
+
# text_attention_mask: (B, L)
|
119 |
+
|
120 |
+
packed_batch_features = dict()
|
121 |
+
|
122 |
+
for key in batch[0].keys():
|
123 |
+
if key == "melspec":
|
124 |
+
packed_batch_features["melspec"] = torch.from_numpy(
|
125 |
+
np.array([b["melspec"][:, :624] for b in batch])
|
126 |
+
)
|
127 |
+
|
128 |
+
if key == "wav":
|
129 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
130 |
+
packed_batch_features[key] = pad_sequence(
|
131 |
+
values, batch_first=True, padding_value=0
|
132 |
+
)
|
133 |
+
|
134 |
+
if key == "caption":
|
135 |
+
captions = [b[key] for b in batch]
|
136 |
+
text_input = self.tokenizer(
|
137 |
+
captions, return_tensors="pt", truncation=True, padding="longest"
|
138 |
+
)
|
139 |
+
text_input_ids = text_input["input_ids"]
|
140 |
+
text_attention_mask = text_input["attention_mask"]
|
141 |
+
|
142 |
+
packed_batch_features["text_input_ids"] = text_input_ids
|
143 |
+
packed_batch_features["text_attention_mask"] = text_attention_mask
|
144 |
+
|
145 |
+
return packed_batch_features
|
146 |
+
|
147 |
+
|
148 |
+
class AudioLDMTestDataset(BaseTestDataset):
|
149 |
+
...
|
150 |
+
|
151 |
+
|
152 |
+
class AudioLDMTestCollator(BaseTestCollator):
|
153 |
+
...
|
models/tta/ldm/audioldm_inference.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import torch.nn as nn
|
12 |
+
from collections import OrderedDict
|
13 |
+
import json
|
14 |
+
|
15 |
+
from models.tta.autoencoder.autoencoder import AutoencoderKL
|
16 |
+
from models.tta.ldm.inference_utils.vocoder import Generator
|
17 |
+
from models.tta.ldm.audioldm import AudioLDM
|
18 |
+
from transformers import T5EncoderModel, AutoTokenizer
|
19 |
+
from diffusers import PNDMScheduler
|
20 |
+
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
from scipy.io.wavfile import write
|
23 |
+
|
24 |
+
|
25 |
+
class AttrDict(dict):
|
26 |
+
def __init__(self, *args, **kwargs):
|
27 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
28 |
+
self.__dict__ = self
|
29 |
+
|
30 |
+
|
31 |
+
class AudioLDMInference:
|
32 |
+
def __init__(self, args, cfg):
|
33 |
+
self.cfg = cfg
|
34 |
+
self.args = args
|
35 |
+
|
36 |
+
self.build_autoencoderkl()
|
37 |
+
self.build_textencoder()
|
38 |
+
|
39 |
+
self.model = self.build_model()
|
40 |
+
self.load_state_dict()
|
41 |
+
|
42 |
+
self.build_vocoder()
|
43 |
+
|
44 |
+
self.out_path = self.args.output_dir
|
45 |
+
self.out_mel_path = os.path.join(self.out_path, "mel")
|
46 |
+
self.out_wav_path = os.path.join(self.out_path, "wav")
|
47 |
+
os.makedirs(self.out_mel_path, exist_ok=True)
|
48 |
+
os.makedirs(self.out_wav_path, exist_ok=True)
|
49 |
+
|
50 |
+
def build_autoencoderkl(self):
|
51 |
+
self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
|
52 |
+
self.autoencoder_path = self.cfg.model.autoencoder_path
|
53 |
+
checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
|
54 |
+
self.autoencoderkl.load_state_dict(checkpoint["model"])
|
55 |
+
self.autoencoderkl.cuda(self.args.local_rank)
|
56 |
+
self.autoencoderkl.requires_grad_(requires_grad=False)
|
57 |
+
self.autoencoderkl.eval()
|
58 |
+
|
59 |
+
def build_textencoder(self):
|
60 |
+
self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
|
61 |
+
self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
|
62 |
+
self.text_encoder.cuda(self.args.local_rank)
|
63 |
+
self.text_encoder.requires_grad_(requires_grad=False)
|
64 |
+
self.text_encoder.eval()
|
65 |
+
|
66 |
+
def build_vocoder(self):
|
67 |
+
config_file = os.path.join(self.args.vocoder_config_path)
|
68 |
+
with open(config_file) as f:
|
69 |
+
data = f.read()
|
70 |
+
json_config = json.loads(data)
|
71 |
+
h = AttrDict(json_config)
|
72 |
+
self.vocoder = Generator(h).to(self.args.local_rank)
|
73 |
+
checkpoint_dict = torch.load(
|
74 |
+
self.args.vocoder_path, map_location=self.args.local_rank
|
75 |
+
)
|
76 |
+
self.vocoder.load_state_dict(checkpoint_dict["generator"])
|
77 |
+
|
78 |
+
def build_model(self):
|
79 |
+
self.model = AudioLDM(self.cfg.model.audioldm)
|
80 |
+
return self.model
|
81 |
+
|
82 |
+
def load_state_dict(self):
|
83 |
+
self.checkpoint_path = self.args.checkpoint_path
|
84 |
+
checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
|
85 |
+
self.model.load_state_dict(checkpoint["model"])
|
86 |
+
self.model.cuda(self.args.local_rank)
|
87 |
+
|
88 |
+
def get_text_embedding(self):
|
89 |
+
text = self.args.text
|
90 |
+
|
91 |
+
prompt = [text]
|
92 |
+
|
93 |
+
text_input = self.tokenizer(
|
94 |
+
prompt,
|
95 |
+
max_length=self.tokenizer.model_max_length,
|
96 |
+
truncation=True,
|
97 |
+
padding="do_not_pad",
|
98 |
+
return_tensors="pt",
|
99 |
+
)
|
100 |
+
text_embeddings = self.text_encoder(
|
101 |
+
text_input.input_ids.to(self.args.local_rank)
|
102 |
+
)[0]
|
103 |
+
|
104 |
+
max_length = text_input.input_ids.shape[-1]
|
105 |
+
uncond_input = self.tokenizer(
|
106 |
+
[""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
|
107 |
+
)
|
108 |
+
uncond_embeddings = self.text_encoder(
|
109 |
+
uncond_input.input_ids.to(self.args.local_rank)
|
110 |
+
)[0]
|
111 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
112 |
+
|
113 |
+
return text_embeddings
|
114 |
+
|
115 |
+
def inference(self):
|
116 |
+
text_embeddings = self.get_text_embedding()
|
117 |
+
print(text_embeddings.shape)
|
118 |
+
|
119 |
+
num_steps = self.args.num_steps
|
120 |
+
guidance_scale = self.args.guidance_scale
|
121 |
+
|
122 |
+
noise_scheduler = PNDMScheduler(
|
123 |
+
num_train_timesteps=1000,
|
124 |
+
beta_start=0.00085,
|
125 |
+
beta_end=0.012,
|
126 |
+
beta_schedule="scaled_linear",
|
127 |
+
skip_prk_steps=True,
|
128 |
+
set_alpha_to_one=False,
|
129 |
+
steps_offset=1,
|
130 |
+
prediction_type="epsilon",
|
131 |
+
)
|
132 |
+
|
133 |
+
noise_scheduler.set_timesteps(num_steps)
|
134 |
+
|
135 |
+
latents = torch.randn(
|
136 |
+
(
|
137 |
+
1,
|
138 |
+
self.cfg.model.autoencoderkl.z_channels,
|
139 |
+
80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
|
140 |
+
624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
|
141 |
+
)
|
142 |
+
).to(self.args.local_rank)
|
143 |
+
|
144 |
+
self.model.eval()
|
145 |
+
for t in tqdm(noise_scheduler.timesteps):
|
146 |
+
t = t.to(self.args.local_rank)
|
147 |
+
|
148 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
149 |
+
latent_model_input = torch.cat([latents] * 2)
|
150 |
+
|
151 |
+
latent_model_input = noise_scheduler.scale_model_input(
|
152 |
+
latent_model_input, timestep=t
|
153 |
+
)
|
154 |
+
# print(latent_model_input.shape)
|
155 |
+
|
156 |
+
# predict the noise residual
|
157 |
+
with torch.no_grad():
|
158 |
+
noise_pred = self.model(
|
159 |
+
latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
|
160 |
+
)
|
161 |
+
|
162 |
+
# perform guidance
|
163 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
164 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
165 |
+
noise_pred_text - noise_pred_uncond
|
166 |
+
)
|
167 |
+
|
168 |
+
# compute the previous noisy sample x_t -> x_t-1
|
169 |
+
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
|
170 |
+
# print(latents.shape)
|
171 |
+
|
172 |
+
latents_out = latents
|
173 |
+
print(latents_out.shape)
|
174 |
+
|
175 |
+
with torch.no_grad():
|
176 |
+
mel_out = self.autoencoderkl.decode(latents_out)
|
177 |
+
print(mel_out.shape)
|
178 |
+
|
179 |
+
melspec = mel_out[0, 0].cpu().detach().numpy()
|
180 |
+
plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec)
|
181 |
+
|
182 |
+
self.vocoder.eval()
|
183 |
+
self.vocoder.remove_weight_norm()
|
184 |
+
with torch.no_grad():
|
185 |
+
melspec = np.expand_dims(melspec, 0)
|
186 |
+
melspec = torch.FloatTensor(melspec).to(self.args.local_rank)
|
187 |
+
|
188 |
+
y = self.vocoder(melspec)
|
189 |
+
audio = y.squeeze()
|
190 |
+
audio = audio * 32768.0
|
191 |
+
audio = audio.cpu().numpy().astype("int16")
|
192 |
+
|
193 |
+
write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio)
|
models/tta/ldm/audioldm_trainer.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from models.base.base_trainer import BaseTrainer
|
7 |
+
from diffusers import DDPMScheduler
|
8 |
+
from models.tta.ldm.audioldm_dataset import AudioLDMDataset, AudioLDMCollator
|
9 |
+
from models.tta.autoencoder.autoencoder import AutoencoderKL
|
10 |
+
from models.tta.ldm.audioldm import AudioLDM, UNetModel
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import MSELoss, L1Loss
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
16 |
+
|
17 |
+
from transformers import T5EncoderModel
|
18 |
+
from diffusers import DDPMScheduler
|
19 |
+
|
20 |
+
|
21 |
+
class AudioLDMTrainer(BaseTrainer):
|
22 |
+
def __init__(self, args, cfg):
|
23 |
+
BaseTrainer.__init__(self, args, cfg)
|
24 |
+
self.cfg = cfg
|
25 |
+
|
26 |
+
self.build_autoencoderkl()
|
27 |
+
self.build_textencoder()
|
28 |
+
self.nosie_scheduler = self.build_noise_scheduler()
|
29 |
+
|
30 |
+
self.save_config_file()
|
31 |
+
|
32 |
+
def build_autoencoderkl(self):
|
33 |
+
self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
|
34 |
+
self.autoencoder_path = self.cfg.model.autoencoder_path
|
35 |
+
checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
|
36 |
+
self.autoencoderkl.load_state_dict(checkpoint["model"])
|
37 |
+
self.autoencoderkl.cuda(self.args.local_rank)
|
38 |
+
self.autoencoderkl.requires_grad_(requires_grad=False)
|
39 |
+
self.autoencoderkl.eval()
|
40 |
+
|
41 |
+
def build_textencoder(self):
|
42 |
+
self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
|
43 |
+
self.text_encoder.cuda(self.args.local_rank)
|
44 |
+
self.text_encoder.requires_grad_(requires_grad=False)
|
45 |
+
self.text_encoder.eval()
|
46 |
+
|
47 |
+
def build_noise_scheduler(self):
|
48 |
+
nosie_scheduler = DDPMScheduler(
|
49 |
+
num_train_timesteps=self.cfg.model.noise_scheduler.num_train_timesteps,
|
50 |
+
beta_start=self.cfg.model.noise_scheduler.beta_start,
|
51 |
+
beta_end=self.cfg.model.noise_scheduler.beta_end,
|
52 |
+
beta_schedule=self.cfg.model.noise_scheduler.beta_schedule,
|
53 |
+
clip_sample=self.cfg.model.noise_scheduler.clip_sample,
|
54 |
+
# steps_offset=self.cfg.model.noise_scheduler.steps_offset,
|
55 |
+
# set_alpha_to_one=self.cfg.model.noise_scheduler.set_alpha_to_one,
|
56 |
+
# skip_prk_steps=self.cfg.model.noise_scheduler.skip_prk_steps,
|
57 |
+
prediction_type=self.cfg.model.noise_scheduler.prediction_type,
|
58 |
+
)
|
59 |
+
return nosie_scheduler
|
60 |
+
|
61 |
+
def build_dataset(self):
|
62 |
+
return AudioLDMDataset, AudioLDMCollator
|
63 |
+
|
64 |
+
def build_data_loader(self):
|
65 |
+
Dataset, Collator = self.build_dataset()
|
66 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
67 |
+
datasets_list = []
|
68 |
+
for dataset in self.cfg.dataset:
|
69 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
70 |
+
datasets_list.append(subdataset)
|
71 |
+
train_dataset = ConcatDataset(datasets_list)
|
72 |
+
|
73 |
+
train_collate = Collator(self.cfg)
|
74 |
+
|
75 |
+
# use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
|
76 |
+
train_loader = DataLoader(
|
77 |
+
train_dataset,
|
78 |
+
collate_fn=train_collate,
|
79 |
+
num_workers=self.args.num_workers,
|
80 |
+
batch_size=self.cfg.train.batch_size,
|
81 |
+
pin_memory=False,
|
82 |
+
)
|
83 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
84 |
+
datasets_list = []
|
85 |
+
for dataset in self.cfg.dataset:
|
86 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
87 |
+
datasets_list.append(subdataset)
|
88 |
+
valid_dataset = ConcatDataset(datasets_list)
|
89 |
+
valid_collate = Collator(self.cfg)
|
90 |
+
|
91 |
+
valid_loader = DataLoader(
|
92 |
+
valid_dataset,
|
93 |
+
collate_fn=valid_collate,
|
94 |
+
num_workers=1,
|
95 |
+
batch_size=self.cfg.train.batch_size,
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
raise NotImplementedError("DDP is not supported yet.")
|
99 |
+
# valid_loader = None
|
100 |
+
data_loader = {"train": train_loader, "valid": valid_loader}
|
101 |
+
return data_loader
|
102 |
+
|
103 |
+
def build_optimizer(self):
|
104 |
+
optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
|
105 |
+
return optimizer
|
106 |
+
|
107 |
+
# TODO: check it...
|
108 |
+
def build_scheduler(self):
|
109 |
+
return None
|
110 |
+
# return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
|
111 |
+
|
112 |
+
def write_summary(self, losses, stats):
|
113 |
+
for key, value in losses.items():
|
114 |
+
self.sw.add_scalar(key, value, self.step)
|
115 |
+
|
116 |
+
def write_valid_summary(self, losses, stats):
|
117 |
+
for key, value in losses.items():
|
118 |
+
self.sw.add_scalar(key, value, self.step)
|
119 |
+
|
120 |
+
def build_criterion(self):
|
121 |
+
criterion = nn.MSELoss(reduction="mean")
|
122 |
+
return criterion
|
123 |
+
|
124 |
+
def get_state_dict(self):
|
125 |
+
if self.scheduler != None:
|
126 |
+
state_dict = {
|
127 |
+
"model": self.model.state_dict(),
|
128 |
+
"optimizer": self.optimizer.state_dict(),
|
129 |
+
"scheduler": self.scheduler.state_dict(),
|
130 |
+
"step": self.step,
|
131 |
+
"epoch": self.epoch,
|
132 |
+
"batch_size": self.cfg.train.batch_size,
|
133 |
+
}
|
134 |
+
else:
|
135 |
+
state_dict = {
|
136 |
+
"model": self.model.state_dict(),
|
137 |
+
"optimizer": self.optimizer.state_dict(),
|
138 |
+
"step": self.step,
|
139 |
+
"epoch": self.epoch,
|
140 |
+
"batch_size": self.cfg.train.batch_size,
|
141 |
+
}
|
142 |
+
return state_dict
|
143 |
+
|
144 |
+
def load_model(self, checkpoint):
|
145 |
+
self.step = checkpoint["step"]
|
146 |
+
self.epoch = checkpoint["epoch"]
|
147 |
+
|
148 |
+
self.model.load_state_dict(checkpoint["model"])
|
149 |
+
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
150 |
+
if self.scheduler != None:
|
151 |
+
self.scheduler.load_state_dict(checkpoint["scheduler"])
|
152 |
+
|
153 |
+
def build_model(self):
|
154 |
+
self.model = AudioLDM(self.cfg.model.audioldm)
|
155 |
+
return self.model
|
156 |
+
|
157 |
+
@torch.no_grad()
|
158 |
+
def mel_to_latent(self, melspec):
|
159 |
+
posterior = self.autoencoderkl.encode(melspec)
|
160 |
+
latent = posterior.sample() # (B, 4, 5, 78)
|
161 |
+
return latent
|
162 |
+
|
163 |
+
@torch.no_grad()
|
164 |
+
def get_text_embedding(self, text_input_ids, text_attention_mask):
|
165 |
+
text_embedding = self.text_encoder(
|
166 |
+
input_ids=text_input_ids, attention_mask=text_attention_mask
|
167 |
+
).last_hidden_state
|
168 |
+
return text_embedding # (B, T, 768)
|
169 |
+
|
170 |
+
def train_step(self, data):
|
171 |
+
train_losses = {}
|
172 |
+
total_loss = 0
|
173 |
+
train_stats = {}
|
174 |
+
|
175 |
+
melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
|
176 |
+
latents = self.mel_to_latent(melspec)
|
177 |
+
|
178 |
+
text_embedding = self.get_text_embedding(
|
179 |
+
data["text_input_ids"], data["text_attention_mask"]
|
180 |
+
)
|
181 |
+
|
182 |
+
noise = torch.randn_like(latents).float()
|
183 |
+
|
184 |
+
bsz = latents.shape[0]
|
185 |
+
timesteps = torch.randint(
|
186 |
+
0,
|
187 |
+
self.cfg.model.noise_scheduler.num_train_timesteps,
|
188 |
+
(bsz,),
|
189 |
+
device=latents.device,
|
190 |
+
)
|
191 |
+
timesteps = timesteps.long()
|
192 |
+
|
193 |
+
with torch.no_grad():
|
194 |
+
noisy_latents = self.nosie_scheduler.add_noise(latents, noise, timesteps)
|
195 |
+
|
196 |
+
model_pred = self.model(
|
197 |
+
noisy_latents, timesteps=timesteps, context=text_embedding
|
198 |
+
)
|
199 |
+
|
200 |
+
loss = self.criterion(model_pred, noise)
|
201 |
+
|
202 |
+
train_losses["loss"] = loss
|
203 |
+
total_loss += loss
|
204 |
+
|
205 |
+
self.optimizer.zero_grad()
|
206 |
+
total_loss.backward()
|
207 |
+
self.optimizer.step()
|
208 |
+
|
209 |
+
for item in train_losses:
|
210 |
+
train_losses[item] = train_losses[item].item()
|
211 |
+
|
212 |
+
return train_losses, train_stats, total_loss.item()
|
213 |
+
|
214 |
+
# TODO: eval step
|
215 |
+
@torch.no_grad()
|
216 |
+
def eval_step(self, data, index):
|
217 |
+
valid_loss = {}
|
218 |
+
total_valid_loss = 0
|
219 |
+
valid_stats = {}
|
220 |
+
|
221 |
+
melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
|
222 |
+
latents = self.mel_to_latent(melspec)
|
223 |
+
|
224 |
+
text_embedding = self.get_text_embedding(
|
225 |
+
data["text_input_ids"], data["text_attention_mask"]
|
226 |
+
)
|
227 |
+
|
228 |
+
noise = torch.randn_like(latents).float()
|
229 |
+
|
230 |
+
bsz = latents.shape[0]
|
231 |
+
timesteps = torch.randint(
|
232 |
+
0,
|
233 |
+
self.cfg.model.noise_scheduler.num_train_timesteps,
|
234 |
+
(bsz,),
|
235 |
+
device=latents.device,
|
236 |
+
)
|
237 |
+
timesteps = timesteps.long()
|
238 |
+
|
239 |
+
noisy_latents = self.nosie_scheduler.add_noise(latents, noise, timesteps)
|
240 |
+
|
241 |
+
model_pred = self.model(noisy_latents, timesteps, text_embedding)
|
242 |
+
|
243 |
+
loss = self.criterion(model_pred, noise)
|
244 |
+
valid_loss["loss"] = loss
|
245 |
+
|
246 |
+
total_valid_loss += loss
|
247 |
+
|
248 |
+
for item in valid_loss:
|
249 |
+
valid_loss[item] = valid_loss[item].item()
|
250 |
+
|
251 |
+
return valid_loss, valid_stats, total_valid_loss.item()
|
models/tta/ldm/inference_utils/utils.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import glob
|
7 |
+
import os
|
8 |
+
import matplotlib
|
9 |
+
import torch
|
10 |
+
from torch.nn.utils import weight_norm
|
11 |
+
|
12 |
+
matplotlib.use("Agg")
|
13 |
+
import matplotlib.pylab as plt
|
14 |
+
|
15 |
+
|
16 |
+
def plot_spectrogram(spectrogram):
|
17 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
18 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
19 |
+
plt.colorbar(im, ax=ax)
|
20 |
+
|
21 |
+
fig.canvas.draw()
|
22 |
+
plt.close()
|
23 |
+
|
24 |
+
return fig
|
25 |
+
|
26 |
+
|
27 |
+
def init_weights(m, mean=0.0, std=0.01):
|
28 |
+
classname = m.__class__.__name__
|
29 |
+
if classname.find("Conv") != -1:
|
30 |
+
m.weight.data.normal_(mean, std)
|
31 |
+
|
32 |
+
|
33 |
+
def apply_weight_norm(m):
|
34 |
+
classname = m.__class__.__name__
|
35 |
+
if classname.find("Conv") != -1:
|
36 |
+
weight_norm(m)
|
37 |
+
|
38 |
+
|
39 |
+
def get_padding(kernel_size, dilation=1):
|
40 |
+
return int((kernel_size * dilation - dilation) / 2)
|
41 |
+
|
42 |
+
|
43 |
+
def load_checkpoint(filepath, device):
|
44 |
+
assert os.path.isfile(filepath)
|
45 |
+
print("Loading '{}'".format(filepath))
|
46 |
+
checkpoint_dict = torch.load(filepath, map_location=device)
|
47 |
+
print("Complete.")
|
48 |
+
return checkpoint_dict
|
49 |
+
|
50 |
+
|
51 |
+
def save_checkpoint(filepath, obj):
|
52 |
+
print("Saving checkpoint to {}".format(filepath))
|
53 |
+
torch.save(obj, filepath)
|
54 |
+
print("Complete.")
|
55 |
+
|
56 |
+
|
57 |
+
def scan_checkpoint(cp_dir, prefix):
|
58 |
+
pattern = os.path.join(cp_dir, prefix + "????????")
|
59 |
+
cp_list = glob.glob(pattern)
|
60 |
+
if len(cp_list) == 0:
|
61 |
+
return None
|
62 |
+
return sorted(cp_list)[-1]
|
models/tta/ldm/inference_utils/vocoder.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
10 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
11 |
+
from models.tta.ldm.inference_utils.utils import get_padding, init_weights
|
12 |
+
|
13 |
+
LRELU_SLOPE = 0.1
|
14 |
+
|
15 |
+
|
16 |
+
class ResBlock1(torch.nn.Module):
|
17 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
18 |
+
super(ResBlock1, self).__init__()
|
19 |
+
self.h = h
|
20 |
+
self.convs1 = nn.ModuleList(
|
21 |
+
[
|
22 |
+
weight_norm(
|
23 |
+
Conv1d(
|
24 |
+
channels,
|
25 |
+
channels,
|
26 |
+
kernel_size,
|
27 |
+
1,
|
28 |
+
dilation=dilation[0],
|
29 |
+
padding=get_padding(kernel_size, dilation[0]),
|
30 |
+
)
|
31 |
+
),
|
32 |
+
weight_norm(
|
33 |
+
Conv1d(
|
34 |
+
channels,
|
35 |
+
channels,
|
36 |
+
kernel_size,
|
37 |
+
1,
|
38 |
+
dilation=dilation[1],
|
39 |
+
padding=get_padding(kernel_size, dilation[1]),
|
40 |
+
)
|
41 |
+
),
|
42 |
+
weight_norm(
|
43 |
+
Conv1d(
|
44 |
+
channels,
|
45 |
+
channels,
|
46 |
+
kernel_size,
|
47 |
+
1,
|
48 |
+
dilation=dilation[2],
|
49 |
+
padding=get_padding(kernel_size, dilation[2]),
|
50 |
+
)
|
51 |
+
),
|
52 |
+
]
|
53 |
+
)
|
54 |
+
self.convs1.apply(init_weights)
|
55 |
+
|
56 |
+
self.convs2 = nn.ModuleList(
|
57 |
+
[
|
58 |
+
weight_norm(
|
59 |
+
Conv1d(
|
60 |
+
channels,
|
61 |
+
channels,
|
62 |
+
kernel_size,
|
63 |
+
1,
|
64 |
+
dilation=1,
|
65 |
+
padding=get_padding(kernel_size, 1),
|
66 |
+
)
|
67 |
+
),
|
68 |
+
weight_norm(
|
69 |
+
Conv1d(
|
70 |
+
channels,
|
71 |
+
channels,
|
72 |
+
kernel_size,
|
73 |
+
1,
|
74 |
+
dilation=1,
|
75 |
+
padding=get_padding(kernel_size, 1),
|
76 |
+
)
|
77 |
+
),
|
78 |
+
weight_norm(
|
79 |
+
Conv1d(
|
80 |
+
channels,
|
81 |
+
channels,
|
82 |
+
kernel_size,
|
83 |
+
1,
|
84 |
+
dilation=1,
|
85 |
+
padding=get_padding(kernel_size, 1),
|
86 |
+
)
|
87 |
+
),
|
88 |
+
]
|
89 |
+
)
|
90 |
+
self.convs2.apply(init_weights)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
94 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
95 |
+
xt = c1(xt)
|
96 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
97 |
+
xt = c2(xt)
|
98 |
+
x = xt + x
|
99 |
+
return x
|
100 |
+
|
101 |
+
def remove_weight_norm(self):
|
102 |
+
for l in self.convs1:
|
103 |
+
remove_weight_norm(l)
|
104 |
+
for l in self.convs2:
|
105 |
+
remove_weight_norm(l)
|
106 |
+
|
107 |
+
|
108 |
+
class ResBlock2(torch.nn.Module):
|
109 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
110 |
+
super(ResBlock2, self).__init__()
|
111 |
+
self.h = h
|
112 |
+
self.convs = nn.ModuleList(
|
113 |
+
[
|
114 |
+
weight_norm(
|
115 |
+
Conv1d(
|
116 |
+
channels,
|
117 |
+
channels,
|
118 |
+
kernel_size,
|
119 |
+
1,
|
120 |
+
dilation=dilation[0],
|
121 |
+
padding=get_padding(kernel_size, dilation[0]),
|
122 |
+
)
|
123 |
+
),
|
124 |
+
weight_norm(
|
125 |
+
Conv1d(
|
126 |
+
channels,
|
127 |
+
channels,
|
128 |
+
kernel_size,
|
129 |
+
1,
|
130 |
+
dilation=dilation[1],
|
131 |
+
padding=get_padding(kernel_size, dilation[1]),
|
132 |
+
)
|
133 |
+
),
|
134 |
+
]
|
135 |
+
)
|
136 |
+
self.convs.apply(init_weights)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
for c in self.convs:
|
140 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
141 |
+
xt = c(xt)
|
142 |
+
x = xt + x
|
143 |
+
return x
|
144 |
+
|
145 |
+
def remove_weight_norm(self):
|
146 |
+
for l in self.convs:
|
147 |
+
remove_weight_norm(l)
|
148 |
+
|
149 |
+
|
150 |
+
class Generator(torch.nn.Module):
|
151 |
+
def __init__(self, h):
|
152 |
+
super(Generator, self).__init__()
|
153 |
+
self.h = h
|
154 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
155 |
+
self.num_upsamples = len(h.upsample_rates)
|
156 |
+
self.conv_pre = weight_norm(
|
157 |
+
Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
|
158 |
+
)
|
159 |
+
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
|
160 |
+
|
161 |
+
self.ups = nn.ModuleList()
|
162 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
163 |
+
self.ups.append(
|
164 |
+
weight_norm(
|
165 |
+
ConvTranspose1d(
|
166 |
+
h.upsample_initial_channel // (2**i),
|
167 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
168 |
+
k,
|
169 |
+
u,
|
170 |
+
padding=(k - u) // 2,
|
171 |
+
)
|
172 |
+
)
|
173 |
+
)
|
174 |
+
|
175 |
+
self.resblocks = nn.ModuleList()
|
176 |
+
for i in range(len(self.ups)):
|
177 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
178 |
+
for j, (k, d) in enumerate(
|
179 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
180 |
+
):
|
181 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
182 |
+
|
183 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
184 |
+
self.ups.apply(init_weights)
|
185 |
+
self.conv_post.apply(init_weights)
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
x = self.conv_pre(x)
|
189 |
+
for i in range(self.num_upsamples):
|
190 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
191 |
+
x = self.ups[i](x)
|
192 |
+
xs = None
|
193 |
+
for j in range(self.num_kernels):
|
194 |
+
if xs is None:
|
195 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
196 |
+
else:
|
197 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
198 |
+
x = xs / self.num_kernels
|
199 |
+
x = F.leaky_relu(x)
|
200 |
+
x = self.conv_post(x)
|
201 |
+
x = torch.tanh(x)
|
202 |
+
|
203 |
+
return x
|
204 |
+
|
205 |
+
def remove_weight_norm(self):
|
206 |
+
print("Removing weight norm...")
|
207 |
+
for l in self.ups:
|
208 |
+
remove_weight_norm(l)
|
209 |
+
for l in self.resblocks:
|
210 |
+
l.remove_weight_norm()
|
211 |
+
remove_weight_norm(self.conv_pre)
|
212 |
+
remove_weight_norm(self.conv_post)
|
213 |
+
|
214 |
+
|
215 |
+
class DiscriminatorP(torch.nn.Module):
|
216 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
217 |
+
super(DiscriminatorP, self).__init__()
|
218 |
+
self.period = period
|
219 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
220 |
+
self.convs = nn.ModuleList(
|
221 |
+
[
|
222 |
+
norm_f(
|
223 |
+
Conv2d(
|
224 |
+
1,
|
225 |
+
32,
|
226 |
+
(kernel_size, 1),
|
227 |
+
(stride, 1),
|
228 |
+
padding=(get_padding(5, 1), 0),
|
229 |
+
)
|
230 |
+
),
|
231 |
+
norm_f(
|
232 |
+
Conv2d(
|
233 |
+
32,
|
234 |
+
128,
|
235 |
+
(kernel_size, 1),
|
236 |
+
(stride, 1),
|
237 |
+
padding=(get_padding(5, 1), 0),
|
238 |
+
)
|
239 |
+
),
|
240 |
+
norm_f(
|
241 |
+
Conv2d(
|
242 |
+
128,
|
243 |
+
512,
|
244 |
+
(kernel_size, 1),
|
245 |
+
(stride, 1),
|
246 |
+
padding=(get_padding(5, 1), 0),
|
247 |
+
)
|
248 |
+
),
|
249 |
+
norm_f(
|
250 |
+
Conv2d(
|
251 |
+
512,
|
252 |
+
1024,
|
253 |
+
(kernel_size, 1),
|
254 |
+
(stride, 1),
|
255 |
+
padding=(get_padding(5, 1), 0),
|
256 |
+
)
|
257 |
+
),
|
258 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
259 |
+
]
|
260 |
+
)
|
261 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
fmap = []
|
265 |
+
|
266 |
+
# 1d to 2d
|
267 |
+
b, c, t = x.shape
|
268 |
+
if t % self.period != 0: # pad first
|
269 |
+
n_pad = self.period - (t % self.period)
|
270 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
271 |
+
t = t + n_pad
|
272 |
+
x = x.view(b, c, t // self.period, self.period)
|
273 |
+
|
274 |
+
for l in self.convs:
|
275 |
+
x = l(x)
|
276 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
277 |
+
fmap.append(x)
|
278 |
+
x = self.conv_post(x)
|
279 |
+
fmap.append(x)
|
280 |
+
x = torch.flatten(x, 1, -1)
|
281 |
+
|
282 |
+
return x, fmap
|
283 |
+
|
284 |
+
|
285 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
286 |
+
def __init__(self):
|
287 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
288 |
+
self.discriminators = nn.ModuleList(
|
289 |
+
[
|
290 |
+
DiscriminatorP(2),
|
291 |
+
DiscriminatorP(3),
|
292 |
+
DiscriminatorP(5),
|
293 |
+
DiscriminatorP(7),
|
294 |
+
DiscriminatorP(11),
|
295 |
+
]
|
296 |
+
)
|
297 |
+
|
298 |
+
def forward(self, y, y_hat):
|
299 |
+
y_d_rs = []
|
300 |
+
y_d_gs = []
|
301 |
+
fmap_rs = []
|
302 |
+
fmap_gs = []
|
303 |
+
for i, d in enumerate(self.discriminators):
|
304 |
+
y_d_r, fmap_r = d(y)
|
305 |
+
y_d_g, fmap_g = d(y_hat)
|
306 |
+
y_d_rs.append(y_d_r)
|
307 |
+
fmap_rs.append(fmap_r)
|
308 |
+
y_d_gs.append(y_d_g)
|
309 |
+
fmap_gs.append(fmap_g)
|
310 |
+
|
311 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
312 |
+
|
313 |
+
|
314 |
+
class DiscriminatorS(torch.nn.Module):
|
315 |
+
def __init__(self, use_spectral_norm=False):
|
316 |
+
super(DiscriminatorS, self).__init__()
|
317 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
318 |
+
self.convs = nn.ModuleList(
|
319 |
+
[
|
320 |
+
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
321 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
322 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
323 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
324 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
325 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
326 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
327 |
+
]
|
328 |
+
)
|
329 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
330 |
+
|
331 |
+
def forward(self, x):
|
332 |
+
fmap = []
|
333 |
+
for l in self.convs:
|
334 |
+
x = l(x)
|
335 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
336 |
+
fmap.append(x)
|
337 |
+
x = self.conv_post(x)
|
338 |
+
fmap.append(x)
|
339 |
+
x = torch.flatten(x, 1, -1)
|
340 |
+
|
341 |
+
return x, fmap
|
342 |
+
|
343 |
+
|
344 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
345 |
+
def __init__(self):
|
346 |
+
super(MultiScaleDiscriminator, self).__init__()
|
347 |
+
self.discriminators = nn.ModuleList(
|
348 |
+
[
|
349 |
+
DiscriminatorS(use_spectral_norm=True),
|
350 |
+
DiscriminatorS(),
|
351 |
+
DiscriminatorS(),
|
352 |
+
]
|
353 |
+
)
|
354 |
+
self.meanpools = nn.ModuleList(
|
355 |
+
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
|
356 |
+
)
|
357 |
+
|
358 |
+
def forward(self, y, y_hat):
|
359 |
+
y_d_rs = []
|
360 |
+
y_d_gs = []
|
361 |
+
fmap_rs = []
|
362 |
+
fmap_gs = []
|
363 |
+
for i, d in enumerate(self.discriminators):
|
364 |
+
if i != 0:
|
365 |
+
y = self.meanpools[i - 1](y)
|
366 |
+
y_hat = self.meanpools[i - 1](y_hat)
|
367 |
+
y_d_r, fmap_r = d(y)
|
368 |
+
y_d_g, fmap_g = d(y_hat)
|
369 |
+
y_d_rs.append(y_d_r)
|
370 |
+
fmap_rs.append(fmap_r)
|
371 |
+
y_d_gs.append(y_d_g)
|
372 |
+
fmap_gs.append(fmap_g)
|
373 |
+
|
374 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
375 |
+
|
376 |
+
|
377 |
+
def feature_loss(fmap_r, fmap_g):
|
378 |
+
loss = 0
|
379 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
380 |
+
for rl, gl in zip(dr, dg):
|
381 |
+
loss += torch.mean(torch.abs(rl - gl))
|
382 |
+
|
383 |
+
return loss * 2
|
384 |
+
|
385 |
+
|
386 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
387 |
+
loss = 0
|
388 |
+
r_losses = []
|
389 |
+
g_losses = []
|
390 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
391 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
392 |
+
g_loss = torch.mean(dg**2)
|
393 |
+
loss += r_loss + g_loss
|
394 |
+
r_losses.append(r_loss.item())
|
395 |
+
g_losses.append(g_loss.item())
|
396 |
+
|
397 |
+
return loss, r_losses, g_losses
|
398 |
+
|
399 |
+
|
400 |
+
def generator_loss(disc_outputs):
|
401 |
+
loss = 0
|
402 |
+
gen_losses = []
|
403 |
+
for dg in disc_outputs:
|
404 |
+
l = torch.mean((1 - dg) ** 2)
|
405 |
+
gen_losses.append(l)
|
406 |
+
loss += l
|
407 |
+
|
408 |
+
return loss, gen_losses
|
models/tts/naturalspeech2/ns2_dataset.py
CHANGED
@@ -21,13 +21,11 @@ class NS2Dataset(torch.utils.data.Dataset):
|
|
21 |
assert isinstance(dataset, str)
|
22 |
|
23 |
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
24 |
-
# for example: /home/v-detaixin/LibriTTS/processed_data; train-full
|
25 |
|
26 |
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
27 |
# train.json
|
28 |
|
29 |
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
30 |
-
# /home/v-detaixin/LibriTTS/processed_data/train-full/train.json
|
31 |
|
32 |
self.metadata = self.get_metadata()
|
33 |
|
|
|
21 |
assert isinstance(dataset, str)
|
22 |
|
23 |
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
|
|
24 |
|
25 |
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
26 |
# train.json
|
27 |
|
28 |
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
|
|
29 |
|
30 |
self.metadata = self.get_metadata()
|
31 |
|
models/vocoders/autoregressive/autoregressive_vocoder_dataset.py
ADDED
File without changes
|
models/vocoders/autoregressive/autoregressive_vocoder_inference.py
ADDED
File without changes
|