Shivam Mehta committed
Commit: 3c10b34 · 1 Parent(s): f5a235a
Adding code
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- Makefile +30 -0
- app.py +253 -0
- diff_ttsg/__init__.py +0 -0
- diff_ttsg/__pycache__/__init__.cpython-310.pyc +0 -0
- diff_ttsg/data/__init__.py +0 -0
- diff_ttsg/data/__pycache__/__init__.cpython-310.pyc +0 -0
- diff_ttsg/data/__pycache__/cormac_datamodule.cpython-310.pyc +0 -0
- diff_ttsg/data/components/__init__.py +0 -0
- diff_ttsg/data/cormac_datamodule.py +214 -0
- diff_ttsg/data/mnist_datamodule.py +130 -0
- diff_ttsg/eval.py +93 -0
- diff_ttsg/hifigan/LICENSE +21 -0
- diff_ttsg/hifigan/README.md +105 -0
- diff_ttsg/hifigan/__init__.py +0 -0
- diff_ttsg/hifigan/__pycache__/__init__.cpython-310.pyc +0 -0
- diff_ttsg/hifigan/__pycache__/config.cpython-310.pyc +0 -0
- diff_ttsg/hifigan/__pycache__/denoiser.cpython-310.pyc +0 -0
- diff_ttsg/hifigan/__pycache__/env.cpython-310.pyc +0 -0
- diff_ttsg/hifigan/__pycache__/models.cpython-310.pyc +0 -0
- diff_ttsg/hifigan/__pycache__/xutils.cpython-310.pyc +0 -0
- diff_ttsg/hifigan/config.py +38 -0
- diff_ttsg/hifigan/denoiser.py +64 -0
- diff_ttsg/hifigan/env.py +17 -0
- diff_ttsg/hifigan/meldataset.py +171 -0
- diff_ttsg/hifigan/models.py +286 -0
- diff_ttsg/hifigan/xutils.py +60 -0
- diff_ttsg/models/__init__.py +0 -0
- diff_ttsg/models/__pycache__/__init__.cpython-310.pyc +0 -0
- diff_ttsg/models/__pycache__/diff_ttsg.cpython-310.pyc +0 -0
- diff_ttsg/models/components/__init__.py +0 -0
- diff_ttsg/models/components/__pycache__/__init__.cpython-310.pyc +0 -0
- diff_ttsg/models/components/__pycache__/diffusion.cpython-310.pyc +0 -0
- diff_ttsg/models/components/__pycache__/text_encoder.cpython-310.pyc +0 -0
- diff_ttsg/models/components/__pycache__/transformer.cpython-310.pyc +0 -0
- diff_ttsg/models/components/diffusion.py +376 -0
- diff_ttsg/models/components/text_encoder.py +384 -0
- diff_ttsg/models/components/transformer.py +250 -0
- diff_ttsg/models/diff_ttsg.py +376 -0
- diff_ttsg/models/mnist_module.py +137 -0
- diff_ttsg/resources/cmu_dictionary +0 -0
- diff_ttsg/text/LICENSE +30 -0
- diff_ttsg/text/__init__.py +96 -0
- diff_ttsg/text/__pycache__/__init__.cpython-310.pyc +0 -0
- diff_ttsg/text/__pycache__/cleaners.cpython-310.pyc +0 -0
- diff_ttsg/text/__pycache__/cmudict.cpython-310.pyc +0 -0
- diff_ttsg/text/__pycache__/numbers.cpython-310.pyc +0 -0
- diff_ttsg/text/__pycache__/symbols.cpython-310.pyc +0 -0
- diff_ttsg/text/cleaners.py +73 -0
- diff_ttsg/text/cmudict.py +60 -0
- diff_ttsg/text/numbers.py +72 -0
Makefile
ADDED
@@ -0,0 +1,30 @@
help: ## Show help
	@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

clean: ## Clean autogenerated files
	rm -rf dist
	find . -type f -name "*.DS_Store" -ls -delete
	find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
	find . | grep -E ".pytest_cache" | xargs rm -rf
	find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
	rm -f .coverage

clean-logs: ## Clean logs
	rm -rf logs/**

format: ## Run pre-commit hooks
	pre-commit run -a

sync: ## Merge changes from main branch to your current branch
	git pull
	git pull origin main

test: ## Run not slow tests
	pytest -k "not slow"

test-full: ## Run all tests
	pytest

train: ## Train the model
	python diff_ttsg/train.py run_name=dev
app.py
ADDED
@@ -0,0 +1,253 @@
import argparse
import datetime as dt
import warnings
from pathlib import Path

import ffmpeg
import gradio as gr
import IPython.display as ipd
import joblib as jl
import numpy as np
import soundfile as sf
import torch
from tqdm.auto import tqdm

from diff_ttsg.hifigan.config import v1
from diff_ttsg.hifigan.denoiser import Denoiser
from diff_ttsg.hifigan.env import AttrDict
from diff_ttsg.hifigan.models import Generator as HiFiGAN
from diff_ttsg.models.diff_ttsg import Diff_TTSG
from diff_ttsg.text import cmudict, sequence_to_text, text_to_sequence
from diff_ttsg.text.symbols import symbols
from diff_ttsg.utils.model import denormalize
from diff_ttsg.utils.utils import intersperse, plot_tensor
from pymo.preprocessing import MocapParameterizer
from pymo.viz_tools import render_mp4
from pymo.writers import BVHWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DIFF_TTSG_CHECKPOINT = "diff_ttsg_checkpoint.ckpt"
HIFIGAN_CHECKPOINT = "g_02500000"
MOTION_PIPELINE = "diff_ttsg/resources/data_pipe.expmap_86.1328125fps.sav"
CMU_DICT_PATH = "diff_ttsg/resources/cmu_dictionary"

OUTPUT_FOLDER = "synth_output"

# Model loading tools
def load_model(checkpoint_path):
    model = Diff_TTSG.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    return model

# Vocoder loading tools
def load_vocoder(checkpoint_path):
    h = AttrDict(v1)
    hifigan = HiFiGAN(h).to(device)
    hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan

# Setup text preprocessing
cmu = cmudict.CMUDict(CMU_DICT_PATH)
def process_text(text: str):
    x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols))).to(device)[None]
    x_lengths = torch.LongTensor([x.shape[-1]]).cuda()
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    return {
        'x_orig': text,
        'x': x,
        'x_lengths': x_lengths,
        'x_phones': x_phones
    }

# Setup motion visualisation
motion_pipeline = jl.load(MOTION_PIPELINE)
bvh_writer = BVHWriter()
mocap_params = MocapParameterizer("position")


## Load models

model = load_model(DIFF_TTSG_CHECKPOINT)
vocoder = load_vocoder(HIFIGAN_CHECKPOINT)
denoiser = Denoiser(vocoder, mode='zeros')


# Synthesis functions

@torch.inference_mode()
def synthesise(text, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp):

    ## Number of timesteps to run the reverse denoising process
    n_timesteps = {
        'mel': mel_timestep,
        'motion': motion_timestep,
    }

    ## Sampling temperature
    temperature = {
        'mel': mel_temp,
        'motion': motion_temp
    }
    text_processed = process_text(text)
    t = dt.datetime.now()
    output = model.synthesise(
        text_processed['x'],
        text_processed['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        stoc=False,
        spk=None,
        length_scale=length_scale
    )

    t = (dt.datetime.now() - t).total_seconds()
    print(f'RTF: {t * 22050 / (output["mel"].shape[-1] * 256)}')

    output.update(text_processed)  # merge everything to one dict
    return output

@torch.inference_mode()
def to_waveform(mel, vocoder):
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0)).cpu().squeeze()
    return audio


def to_bvh(motion):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return motion_pipeline.inverse_transform([motion.cpu().squeeze(0).T])


def save_to_folder(filename: str, output: dict, folder: str):
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)
    np.save(folder / f'{filename}', output['mel'].cpu().numpy())
    sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')
    with open(folder / f'{filename}.bvh', 'w') as f:
        bvh_writer.write(output['bvh'], f)


def to_stick_video(filename, bvh, folder):
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        X_pos = mocap_params.fit_transform([bvh])
    print(f"rendering {filename} ...")
    render_mp4(X_pos[0], folder / f'{filename}.mp4', axis_scale=200)


def combine_audio_video(filename: str, folder: str):
    print("Combining audio and video")
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)

    input_video = ffmpeg.input(str(folder / f'{filename}.mp4'))
    input_audio = ffmpeg.input(str(folder / f'{filename}.wav'))
    output_filename = folder / f'{filename}_audio.mp4'
    ffmpeg.concat(input_video, input_audio, v=1, a=1).output(str(output_filename)).run(overwrite_output=True)
    print(f"Final output with audio: {output_filename}")


def run(text, output, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp):
    print("Running synthesis")
    output = synthesise(text, mel_timestep, motion_timestep, length_scale, mel_temp, motion_temp)
    output['waveform'] = to_waveform(output['mel'], vocoder)
    output['bvh'] = to_bvh(output['motion'])[0]
    save_to_folder('temp', output, OUTPUT_FOLDER)
    return (
        output,
        output['x_phones'],
        plot_tensor(output['mel'].squeeze().cpu().numpy()),
        plot_tensor(output['motion'].squeeze().cpu().numpy()),
        str(Path(OUTPUT_FOLDER) / f'temp.wav'),
        gr.update(interactive=True)
    )

def visualize_it(output):
    to_stick_video('temp', output['bvh'], OUTPUT_FOLDER)
    combine_audio_video('temp', OUTPUT_FOLDER)
    return str(Path(OUTPUT_FOLDER) / 'temp_audio.mp4')


with gr.Blocks() as demo:

    output = gr.State(value=None)

    with gr.Row():
        gr.Markdown("# Text Input")
    with gr.Row():
        text = gr.Textbox(label="Text Input")

    with gr.Box():
        with gr.Row():
            gr.Markdown("### Hyper parameters")
        with gr.Row():
            mel_timestep = gr.Slider(label="Number of timesteps (mel)", minimum=0, maximum=1000, step=1, value=50, interactive=True)
            motion_timestep = gr.Slider(label="Number of timesteps (motion)", minimum=0, maximum=1000, step=1, value=500, interactive=True)
            length_scale = gr.Slider(label="Length scale (Speaking rate)", minimum=0.01, maximum=3.0, step=0.05, value=1.15, interactive=True)
            mel_temp = gr.Slider(label="Sampling temperature (mel)", minimum=0.01, maximum=5.0, step=0.05, value=1.3, interactive=True)
            motion_temp = gr.Slider(label="Sampling temperature (motion)", minimum=0.01, maximum=5.0, step=0.05, value=1.5, interactive=True)

    synth_btn = gr.Button("Synthesise")

    with gr.Box():
        with gr.Row():
            gr.Markdown("### Phonetised text")
        with gr.Row():
            phonetised_text = gr.Textbox(label="Phonetised text", interactive=False)

    with gr.Box():
        with gr.Row():
            mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
            motion_representation = gr.Image(interactive=False, label="Motion representation")

    with gr.Row():
        audio = gr.Audio(interactive=False, label="Audio")

    with gr.Box():
        with gr.Row():
            gr.Markdown("### Generate stick figure visualisation")
        with gr.Row():
            gr.Markdown("(This will take a while)")
        with gr.Row():
            visualize = gr.Button("Visualize", interactive=False)

    with gr.Row():
        video = gr.Video(label="Video", interactive=False)

    synth_btn.click(
        fn=run,
        inputs=[
            text,
            output,
            mel_timestep,
            motion_timestep,
            length_scale,
            mel_temp,
            motion_temp
        ],
        outputs=[
            output,
            phonetised_text,
            mel_spectrogram,
            motion_representation,
            audio,
            # video,
            visualize
        ], api_name="diff_ttsg")

    visualize.click(
        fn=visualize_it,
        inputs=[output],
        outputs=[video],
    )

demo.queue(1)
demo.launch()
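For orientation, here is a brief sketch (not part of the commit) of driving the same pipeline with the functions defined in app.py above, without the Gradio UI; the input sentence and output name are placeholders:

# Sketch only: programmatic use of the app.py functions; sentence and filename are hypothetical.
sentence = "Hello there, how are you?"
out = synthesise(sentence, mel_timestep=50, motion_timestep=500,
                 length_scale=1.15, mel_temp=1.3, motion_temp=1.5)
out['waveform'] = to_waveform(out['mel'], vocoder)   # HiFi-GAN + denoiser -> audio tensor
out['bvh'] = to_bvh(out['motion'])[0]                # motion features -> BVH via the pymo pipeline
save_to_folder('example', out, OUTPUT_FOLDER)        # writes example.npy, example.wav, example.bvh

The argument values match the slider defaults in the UI (50/500 timesteps, length scale 1.15, temperatures 1.3/1.5).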
diff_ttsg/__init__.py
ADDED
File without changes
diff_ttsg/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes).
diff_ttsg/data/__init__.py
ADDED
File without changes
diff_ttsg/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes).
diff_ttsg/data/__pycache__/cormac_datamodule.cpython-310.pyc
ADDED
Binary file (7.29 kB).
diff_ttsg/data/components/__init__.py
ADDED
File without changes
diff_ttsg/data/cormac_datamodule.py
ADDED
@@ -0,0 +1,214 @@
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio as ta
from einops import pack
from lightning import LightningDataModule
from torch.utils.data.dataloader import DataLoader

from diff_ttsg.text import cmudict, text_to_sequence
from diff_ttsg.text.symbols import symbols
from diff_ttsg.utils.audio import mel_spectrogram
from diff_ttsg.utils.model import fix_len_compatibility, normalize
from diff_ttsg.utils.utils import intersperse, parse_filelist


class CormacDataModule(LightningDataModule):

    def __init__(
        self,
        train_filelist_path,
        valid_filelist_path,
        batch_size,
        num_workers,
        pin_memory,
        cmudict_path,
        motion_folder,
        add_blank,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
        data_statistics,
        motion_pipeline_filename,
        seed
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        # load and split datasets only if not loaded already

        self.trainset = TextMelDataset(
            self.hparams.train_filelist_path,
            self.hparams.cmudict_path,
            self.hparams.motion_folder,
            self.hparams.add_blank,
            self.hparams.n_fft,
            self.hparams.n_feats,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.f_min,
            self.hparams.f_max,
            self.hparams.data_statistics,
            self.hparams.seed
        )
        self.validset = TextMelDataset(
            self.hparams.valid_filelist_path,
            self.hparams.cmudict_path,
            self.hparams.motion_folder,
            self.hparams.add_blank,
            self.hparams.n_fft,
            self.hparams.n_feats,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.f_min,
            self.hparams.f_max,
            self.hparams.data_statistics,
            self.hparams.seed
        )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.trainset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            collate_fn=TextMelBatchCollate()
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.validset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            collate_fn=TextMelBatchCollate()
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass


class TextMelDataset(torch.utils.data.Dataset):
    def __init__(self, filelist_path, cmudict_path, motion_folder, add_blank=True,
                 n_fft=1024, n_mels=80, sample_rate=22050,
                 hop_length=256, win_length=1024, f_min=0., f_max=8000, data_parameters=None, seed=None):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.motion_fileloc = Path(motion_folder)
        self.cmudict = cmudict.CMUDict(cmudict_path)
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        if data_parameters is not None:
            self.data_parameters = data_parameters
        else:
            self.data_parameters = { 'mel_mean': 0, 'mel_std': 1, 'motion_mean': 0, 'motion_std': 1 }
        random.seed(seed)
        random.shuffle(self.filepaths_and_text)

    def get_pair(self, filepath_and_text):
        filepath, text = filepath_and_text[0], filepath_and_text[1]
        text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)
        motion = self.get_motion(filepath, mel.shape[1])
        return (text, mel, motion)

    def get_motion(self, filename, mel_shape, ext=".expmap_86.1328125fps.pkl"):
        file_loc = self.motion_fileloc / Path(Path(filename).name).with_suffix(ext)
        motion = torch.from_numpy(pd.read_pickle(file_loc).to_numpy())
        motion = F.interpolate(motion.T.unsqueeze(0), mel_shape).squeeze(0)
        motion = normalize(motion, self.data_parameters['motion_mean'], self.data_parameters['motion_std'])
        return motion

    def get_mel(self, filepath):
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate
        mel = mel_spectrogram(audio, self.n_fft, 80, self.sample_rate, self.hop_length,
                              self.win_length, self.f_min, self.f_max, center=False).squeeze()
        mel = normalize(mel, self.data_parameters['mel_mean'], self.data_parameters['mel_std'])
        return mel

    def get_text(self, text, add_blank=True):
        text_norm = text_to_sequence(text, dictionary=self.cmudict)
        if self.add_blank:
            text_norm = intersperse(text_norm, len(symbols))  # add a blank token, whose id number is len(symbols)
        text_norm = torch.IntTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        text, mel, motion = self.get_pair(self.filepaths_and_text[index])
        item = {'y': mel, 'x': text, 'y_motion': motion}
        return item

    def __len__(self):
        return len(self.filepaths_and_text)

    def sample_test_batch(self, size):
        idx = np.random.choice(range(len(self)), size=size, replace=False)
        test_batch = []
        for index in idx:
            test_batch.append(self.__getitem__(index))
        return test_batch


class TextMelBatchCollate(object):
    def __call__(self, batch):
        B = len(batch)
        y_max_length = max([item['y'].shape[-1] for item in batch])
        y_max_length = fix_len_compatibility(y_max_length)
        x_max_length = max([item['x'].shape[-1] for item in batch])
        n_feats = batch[0]['y'].shape[-2]
        n_motion = batch[0]['y_motion'].shape[-2]

        y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
        x = torch.zeros((B, x_max_length), dtype=torch.long)
        y_motion = torch.zeros((B, n_motion, y_max_length), dtype=torch.float32)
        y_lengths, x_lengths = [], []

        for i, item in enumerate(batch):
            y_, x_, y_motion_ = item['y'], item['x'], item['y_motion']
            y_lengths.append(y_.shape[-1])
            x_lengths.append(x_.shape[-1])
            y[i, :, :y_.shape[-1]] = y_
            x[i, :x_.shape[-1]] = x_
            y_motion[i, :, :y_motion_.shape[-1]] = y_motion_

        y_lengths = torch.LongTensor(y_lengths)
        x_lengths = torch.LongTensor(x_lengths)
        return {'x': x, 'x_lengths': x_lengths, 'y': y, 'y_lengths': y_lengths, 'y_motion': y_motion}
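As a quick sketch of what TextMelBatchCollate yields when paired with a DataLoader (the filelist and motion-folder paths below are hypothetical, not files in this commit):

# Sketch only: illustrates the padded batch layout produced above; paths are placeholders.
dataset = TextMelDataset("data/train_filelist.txt",
                         "diff_ttsg/resources/cmu_dictionary",
                         "data/motion")
loader = DataLoader(dataset, batch_size=4, collate_fn=TextMelBatchCollate())
batch = next(iter(loader))
# batch['x']        : (4, max_text_len)           zero-padded phoneme IDs
# batch['y']        : (4, 80, max_mel_len)        normalised mel frames, padded to fix_len_compatibility
# batch['y_motion'] : (4, n_motion, max_mel_len)  motion features interpolated to the mel length
# batch['x_lengths'] / batch['y_lengths'] keep the unpadded length of each item.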
diff_ttsg/data/mnist_datamodule.py
ADDED
@@ -0,0 +1,130 @@
from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTDataModule(LightningDataModule):
    """Example of LightningDataModule for MNIST dataset.

    A DataModule implements 6 key methods:
        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data/",
        train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # data transformations
        self.transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @property
    def num_classes(self):
        return 10

    def prepare_data(self):
        """Download data if needed.

        Do not use it to assign state (self.x = y).
        """
        MNIST(self.hparams.data_dir, train=True, download=True)
        MNIST(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
            dataset = ConcatDataset(datasets=[trainset, testset])
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split,
                generator=torch.Generator().manual_seed(42),
            )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass


if __name__ == "__main__":
    _ = MNISTDataModule()
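A short sketch (standard Lightning usage, not part of the commit) of exercising this datamodule by hand:

# Sketch only: manual use of MNISTDataModule without a Trainer.
dm = MNISTDataModule(data_dir="data/", batch_size=64)
dm.prepare_data()     # downloads MNIST if it is not already in data/
dm.setup()            # builds the 55k/5k/10k random split
images, labels = next(iter(dm.train_dataloader()))
# images: (64, 1, 28, 28) normalised tensors, labels: (64,) digit classes 0-9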
diff_ttsg/eval.py
ADDED
@@ -0,0 +1,93 @@
from typing import List, Tuple

import hydra
import pyrootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

from diff_ttsg import utils

log = utils.get_pylogger(__name__)


@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    assert cfg.ckpt_path

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # for predictions use trainer.predict(...)
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    evaluate(cfg)


if __name__ == "__main__":
    main()
diff_ttsg/hifigan/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Jungil Kong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
diff_ttsg/hifigan/README.md
ADDED
@@ -0,0 +1,105 @@
# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis

### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae

In our [paper](https://arxiv.org/abs/2010.05646),
we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
We provide our implementation and pretrained models as open source in this repository.

**Abstract :**
Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
Although such methods improve the sampling efficiency and memory usage,
their sample quality has not yet reached that of autoregressive and flow-based generative models.
In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
As speech audio consists of sinusoidal signals with various periods,
we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
faster than real-time on CPU with comparable quality to an autoregressive counterpart.

Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.


## Pre-requisites
1. Python >= 3.6
2. Clone this repository.
3. Install python requirements. Please refer [requirements.txt](requirements.txt)
4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
And move all wav files to `LJSpeech-1.1/wavs`


## Training
```
python train.py --config config_v1.json
```
To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
You can change the path by adding `--checkpoint_path` option.

Validation loss during training with V1 generator.<br>
![validation loss](./validation_loss.png)

## Pretrained Model
You can also use pretrained models we provide.<br/>
[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
Details of each folder are as in follows:

|Folder Name|Generator|Dataset|Fine-Tuned|
|------|---|---|---|
|LJ_V1|V1|LJSpeech|No|
|LJ_V2|V2|LJSpeech|No|
|LJ_V3|V3|LJSpeech|No|
|LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
|LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
|LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
|VCTK_V1|V1|VCTK|No|
|VCTK_V2|V2|VCTK|No|
|VCTK_V3|V3|VCTK|No|
|UNIVERSAL_V1|V1|Universal|No|

We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.

## Fine-Tuning
1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
Example:
```
Audio File : LJ001-0001.wav
Mel-Spectrogram File : LJ001-0001.npy
```
2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
3. Run the following command.
```
python train.py --fine_tuning True --config config_v1.json
```
For other command line options, please refer to the training section.


## Inference from wav file
1. Make `test_files` directory and copy wav files into the directory.
2. Run the following command.
```
python inference.py --checkpoint_file [generator checkpoint file path]
```
Generated wav files are saved in `generated_files` by default.<br>
You can change the path by adding `--output_dir` option.


## Inference for end-to-end speech synthesis
1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
[Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
2. Run the following command.
```
python inference_e2e.py --checkpoint_file [generator checkpoint file path]
```
Generated wav files are saved in `generated_files_from_mel` by default.<br>
You can change the path by adding `--output_dir` option.


## Acknowledgements
We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
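In this Space the pretrained generator is loaded by `load_vocoder` in `app.py` above; a condensed sketch of the same mel-to-waveform inversion (the checkpoint name is the one referenced there, the mel tensor is a stand-in):

# Sketch only: mirrors load_vocoder / to_waveform from app.py; mel here is a dummy placeholder.
import torch
from diff_ttsg.hifigan.config import v1
from diff_ttsg.hifigan.env import AttrDict
from diff_ttsg.hifigan.models import Generator

h = AttrDict(v1)
gen = Generator(h)
gen.load_state_dict(torch.load("g_02500000", map_location="cpu")["generator"])
gen.eval()
gen.remove_weight_norm()
mel = torch.randn(1, 80, 100)      # placeholder; normally produced by the acoustic model
with torch.inference_mode():
    wav = gen(mel)                 # (1, 1, 100 * 256) samples at 22.05 kHz (upsample factor 8*8*2*2)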
diff_ttsg/hifigan/__init__.py
ADDED
File without changes
diff_ttsg/hifigan/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes).
diff_ttsg/hifigan/__pycache__/config.cpython-310.pyc
ADDED
Binary file (1.02 kB).
diff_ttsg/hifigan/__pycache__/denoiser.cpython-310.pyc
ADDED
Binary file (2.56 kB).
diff_ttsg/hifigan/__pycache__/env.cpython-310.pyc
ADDED
Binary file (883 Bytes).
diff_ttsg/hifigan/__pycache__/models.cpython-310.pyc
ADDED
Binary file (8.73 kB).
diff_ttsg/hifigan/__pycache__/xutils.cpython-310.pyc
ADDED
Binary file (2.1 kB).
diff_ttsg/hifigan/config.py
ADDED
@@ -0,0 +1,38 @@
v1 = {
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    "learning_rate": 0.0004,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [8,8,2,2],
    "upsample_kernel_sizes": [16,16,4,4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "resblock_initial_channel": 256,

    "segment_size": 8192,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,

    "sampling_rate": 22050,

    "fmin": 0,
    "fmax": 8000,
    "fmax_loss": None,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
diff_ttsg/hifigan/denoiser.py
ADDED
@@ -0,0 +1,64 @@
### Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py

"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
import torch


class Denoiser(torch.nn.Module):
    """Removes model bias from audio produced with waveglow"""

    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = int(filter_length / n_overlap)
        self.win_length = win_length

        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
        self.device = device
        if mode == "zeros":
            mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
        elif mode == "normal":
            mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
        else:
            raise Exception(f"Mode {mode} if not supported")

        def stft_fn(audio, n_fft, hop_length, win_length, window):
            spec = torch.stft(
                audio,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                return_complex=True,
            )
            spec = torch.view_as_real(spec)
            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])

        self.stft = lambda x: stft_fn(
            audio=x,
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=torch.hann_window(self.win_length, device=device)
        )
        self.istft = lambda x, y: torch.istft(
            torch.complex(x * torch.cos(y), x * torch.sin(y)),
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=torch.hann_window(self.win_length, device=device),
        )

        with torch.no_grad():
            bias_audio = vocoder(mel_input).float().squeeze(0)
            bias_spec, _ = self.stft(bias_audio)

        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])

    @torch.inference_mode()
    def forward(self, audio, strength=0.0005):
        audio_spec, audio_angles = self.stft(audio)
        audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.istft(audio_spec_denoised, audio_angles)
        return audio_denoised
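This denoiser is applied to the raw vocoder output exactly as `to_waveform` in `app.py` does; continuing the generator sketch under the README above:

# Sketch only: wrap the HiFi-GAN generator and subtract its bias spectrum from the audio.
denoiser = Denoiser(gen, mode="zeros")
audio = gen(mel).clamp(-1, 1)                          # (1, 1, n_samples)
clean = denoiser(audio.squeeze(0), strength=0.0005)    # larger strength removes more hiss, and more signal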
diff_ttsg/hifigan/env.py
ADDED
@@ -0,0 +1,17 @@
""" from https://github.com/jik876/hifi-gan """

import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
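AttrDict simply exposes dictionary keys as attributes, which is how the `v1` config above is consumed by the generator; a tiny sketch:

# Sketch only: key and attribute access on the v1 config are interchangeable.
from diff_ttsg.hifigan.config import v1
from diff_ttsg.hifigan.env import AttrDict

h = AttrDict(v1)
assert h.n_fft == h["n_fft"] == 1024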
diff_ttsg/hifigan/meldataset.py
ADDED
@@ -0,0 +1,171 @@
""" from https://github.com/jik876/hifi-gan """

import math
import os
import random

import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from librosa.util import normalize
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.view_as_real(torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                                         center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True))

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


def get_dataset_filelist(a):
    with open(a.input_training_file, 'r', encoding='utf-8') as fi:
        training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
                          for x in fi.read().split('\n') if len(x) > 0]

    with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
        validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
                            for x in fi.read().split('\n') if len(x) > 0]
    return training_files, validation_files


class MelDataset(torch.utils.data.Dataset):
    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate))
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start:audio_start+self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

            mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                  self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                                  center=False)
        else:
            mel = np.load(
                os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)
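A small sketch of computing a mel spectrogram with the function above using the v1 analysis settings (the wav path is a placeholder and is assumed to be 22050 Hz):

# Sketch only: wav -> 80-bin log-mel; "example.wav" is hypothetical.
import torch
from scipy.io.wavfile import read

sr, data = read("example.wav")
audio = torch.FloatTensor(data / MAX_WAV_VALUE).unsqueeze(0)   # (1, n_samples) scaled to [-1, 1]
mel = mel_spectrogram(audio, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False)
# mel: (1, 80, n_frames) log-compressed mel magnitudes, roughly one frame per 256 samples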
diff_ttsg/hifigan/models.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" from https://github.com/jik876/hifi-gan """

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from .xutils import get_padding, init_weights

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
diff_ttsg/hifigan/xutils.py
ADDED
@@ -0,0 +1,60 @@
""" from https://github.com/jik876/hifi-gan """

import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
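A short sketch of how scan_checkpoint and load_checkpoint are meant to be combined when restoring a vocoder. The directory, the "g_" prefix, and the "generator" state-dict key are assumptions based on the upstream HiFi-GAN training layout, not paths taken from this repository:

import torch
from diff_ttsg.hifigan.xutils import load_checkpoint, scan_checkpoint

cp_g = scan_checkpoint("checkpoints/hifigan", "g_")     # newest file matching g_????????
if cp_g is not None:
    state = load_checkpoint(cp_g, torch.device("cpu"))
    # upstream HiFi-GAN checkpoints typically keep the generator weights under "generator"
    generator_state_dict = state["generator"]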
diff_ttsg/models/__init__.py
ADDED
File without changes
diff_ttsg/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (159 Bytes)
diff_ttsg/models/__pycache__/diff_ttsg.cpython-310.pyc
ADDED
Binary file (11.2 kB)
diff_ttsg/models/components/__init__.py
ADDED
File without changes
diff_ttsg/models/components/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (170 Bytes)
diff_ttsg/models/components/__pycache__/diffusion.cpython-310.pyc
ADDED
Binary file (12.6 kB)
diff_ttsg/models/components/__pycache__/text_encoder.cpython-310.pyc
ADDED
Binary file (12.3 kB)
diff_ttsg/models/components/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (6.03 kB)
diff_ttsg/models/components/diffusion.py
ADDED
@@ -0,0 +1,376 @@
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

import math

import torch
from diffusers import UNet1DModel
from einops import pack, rearrange


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(torch.nn.Module):
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(torch.nn.Module):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
                                         padding=1), torch.nn.GroupNorm(
                                         groups, dim_out), Mish())

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
                                                               dim_out))

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
                        heads=self.heads, h=h, w=w)
        return self.to_out(out)


class Residual(torch.nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class UNet1DDiffuser(torch.nn.Module):
    def __init__(self, in_channels=90, out_channels=45, block_out_channels=(256, 512)):
        super(UNet1DDiffuser, self).__init__()

        self.unet = UNet1DModel(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=("DownBlock1DNoSkip", "AttnDownBlock1D"),
            up_block_types=("AttnUpBlock1D", "UpBlock1DNoSkip"),
            mid_block_type="UNetMidBlock1D",
            block_out_channels=block_out_channels,
            use_timestep_embedding=True,
        )

    def forward(self, x, mask, mu, t, spk=None):
        x = pack([x, mu], "b * t")[0]

        return self.unet(x, t).sample * mask


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class GradLogPEstimator2d(torch.nn.Module):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8,
                 n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(GradLogPEstimator2d, self).__init__()
        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
                                       torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(torch.nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                Residual(Rezero(LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else torch.nn.Identity()]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(torch.nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                Residual(Rezero(LinearAttention(dim_in))),
                Upsample(dim_in)]))
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)


def get_noise(t, beta_init, beta_term, cumulative=False):
    if cumulative:
        noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
    else:
        noise = beta_init + (beta_term - beta_init)*t
    return noise


class Diffusion(torch.nn.Module):
    def __init__(self, n_feats, dim,
                 n_spks=1, spk_emb_dim=64,
                 beta_min=0.05, beta_max=20, pe_scale=1000):
        super(Diffusion, self).__init__()
        self.n_feats = n_feats
        self.dim = dim
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale

        self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks,
                                             spk_emb_dim=spk_emb_dim,
                                             pe_scale=pe_scale)

    def forward_diffusion(self, x0, mask, mu, t):
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
        variance = 1.0 - torch.exp(-cum_noise)
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
                        requires_grad=False)
        xt = mean + z * torch.sqrt(variance)
        return xt * mask, z * mask

    @torch.no_grad()
    def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
                                                 device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max,
                                cumulative=False)
            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                       requires_grad=False)
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

    @torch.no_grad()
    def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)

    def loss_t(self, x0, mask, mu, t, spk=None):
        xt, z = self.forward_diffusion(x0, mask, mu, t)
        time = t.unsqueeze(-1).unsqueeze(-1)  # t = [0.6215, 0.0191, 0.0391]
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        noise_estimation = self.estimator(xt, mask, mu, t, spk)  # xt = [3, 80, 172], mask = [3, 1, 172], mu = [3, 80, 172], t = [3]
        noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
        return loss, xt

    def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
        t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
                       requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x0, mask, mu, t, spk)


class Diffusion_Motion(torch.nn.Module):
    def __init__(self, in_channels, motion_decoder_channels=(256, 256), beta_min=0.05, beta_max=20):
        super(Diffusion_Motion, self).__init__()
        self.in_channels = in_channels
        self.beta_min = beta_min
        self.beta_max = beta_max

        self.estimator = UNet1DDiffuser(block_out_channels=motion_decoder_channels)

    def forward_diffusion(self, x0, mask, mu, t):
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
        variance = 1.0 - torch.exp(-cum_noise)
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
                        requires_grad=False)
        xt = mean + z * torch.sqrt(variance)
        return xt * mask, z * mask

    @torch.no_grad()
    def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
                                                 device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max,
                                cumulative=False)
            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                       requires_grad=False)
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

    @torch.no_grad()
    def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)

    def loss_t(self, x0, mask, mu, t, spk=None):
        xt, z = self.forward_diffusion(x0, mask, mu, t)
        time = t.unsqueeze(-1).unsqueeze(-1)  # t = [0.6215, 0.0191, 0.0391]
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        noise_estimation = self.estimator(xt, mask, mu, t, spk)  # xt = [3, 80, 172], mask = [3, 1, 172], mu = [3, 80, 172], t = [3]
        noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.in_channels)
        return loss, xt

    def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
        t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
                       requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x0, mask, mu, t, spk)
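Diffusion and Diffusion_Motion share the same forward/reverse process: forward_diffusion perturbs x0 towards mu under the linear noise schedule beta(t) = beta_min + (beta_max - beta_min)·t (integrated in get_noise when cumulative=True), and reverse_diffusion integrates the reverse process over n_timesteps Euler steps, optionally with a stochastic term. A minimal sketch of sampling a mel-spectrogram with the decoder; mu is random here only to show the call signature and shapes, whereas in the real model it comes from the duration-upsampled text-encoder output:

import torch
from diff_ttsg.models.components.diffusion import Diffusion

decoder = Diffusion(n_feats=80, dim=64)      # dims are illustrative, not this repo's config
mu = torch.randn(1, 80, 172)                 # encoder output aligned to the output length
mask = torch.ones(1, 1, 172)                 # all frames valid (length divisible by 4)
z = mu + torch.randn_like(mu)                # terminal noise centred on mu, as in Grad-TTS
with torch.no_grad():
    mel = decoder(z, mask, mu, n_timesteps=50, stoc=False)   # (1, 80, 172)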
diff_ttsg/models/components/text_encoder.py
ADDED
@@ -0,0 +1,384 @@
""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch
import torch.nn as nn
from conformer import ConformerBlock
from einops import rearrange

from diff_ttsg.models.components.transformer import FFTransformer
from diff_ttsg.utils.model import convert_pad_shape, sequence_mask


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean)**2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                 n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
                                                kernel_size, padding=kernel_size//2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
                                                    kernel_size, padding=kernel_size//2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super(DurationPredictor, self).__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
                                      kernel_size, padding=kernel_size//2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
                                      kernel_size, padding=kernel_size//2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, window_size=None,
                 heads_share=True, p_dropout=0.0, proximal_bias=False,
                 proximal_init=False):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
                                                window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
                                                window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
                                                                    dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights,
                                                                value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0],
                                                        [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:,
                                                              slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length-1]]))
        x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
        x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2*length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
                 p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
                                      padding=kernel_size//2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
                                      padding=kernel_size//2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
                 kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
                                    n_heads, window_size=window_size, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
                                   filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(nn.Module):
    def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
                 filter_channels_dp, n_heads, n_layers, kernel_size,
                 p_dropout, window_size=None, spk_emb_dim=64, n_spks=1, encoder_type=None):
        super(TextEncoder, self).__init__()
        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
                                   kernel_size=5, n_layers=3, p_dropout=0.5)
        if encoder_type == "default":
            self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
                                   kernel_size, p_dropout, window_size=window_size)
        elif encoder_type == "myencoder":
            self.encoder = FFTransformer(
                n_layers, n_heads, n_channels + (spk_emb_dim if n_spks > 1 else 0), 64, 1024, kernel_size,
                p_dropout, p_dropout, rel_attention=False, rel_window_size=window_size
            )
        else:
            raise ValueError(f"Unknown encoder type: {encoder_type}")

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
                                        kernel_size, p_dropout)

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask


class MuMotionEncoder(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        hidden_channels,
        d_head,
        n_layer,
        n_head,
        ff_mult,
        conv_expansion_factor,
        dropout,
        dropatt,
        dropconv,
        conv_kernel_size,
    ) -> None:
        super().__init__()

        self.in_projection = nn.Conv1d(input_channels, hidden_channels, 1)
        self.layers = nn.ModuleList()
        for _ in range(n_layer):
            self.layers.append(
                ConformerBlock(
                    dim=hidden_channels,
                    dim_head=d_head,
                    heads=n_head,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    ff_dropout=dropout,
                    attn_dropout=dropatt,
                    conv_dropout=dropconv,
                    conv_kernel_size=conv_kernel_size,
                )
            )

        self.motion_projection = nn.Conv1d(hidden_channels, output_channels, 1)

    def forward(self, x, mask):
        x = self.in_projection(x)
        x = rearrange(x, "b c t -> b t c")
        mask = rearrange(mask, "b 1 t -> b (1 t)").bool()
        for layer in self.layers:
            x = layer(x, mask)
        x = rearrange(x, "b t c -> b c t")
        x = self.motion_projection(x)
        return x
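A minimal sketch of calling TextEncoder on a padded batch of symbol IDs. The hyperparameter values are illustrative (the real ones come from the training configs); mu is the frame-level prior per input token and logw the predicted log-durations:

import torch
from diff_ttsg.models.components.text_encoder import TextEncoder

encoder = TextEncoder(
    n_vocab=149, n_feats=80, n_channels=192, filter_channels=768,
    filter_channels_dp=256, n_heads=2, n_layers=6, kernel_size=3,
    p_dropout=0.1, window_size=4, encoder_type="default",
)
x = torch.randint(1, 149, (2, 40))         # two padded phoneme-ID sequences
x_lengths = torch.tensor([40, 35])         # true lengths before padding
mu, logw, x_mask = encoder(x, x_lengths)   # mu: (2, 80, 40), logw: (2, 1, 40), x_mask: (2, 1, 40)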
diff_ttsg/models/components/transformer.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from conformer.conformer import Attention as RelAttention
|
19 |
+
from einops import rearrange
|
20 |
+
|
21 |
+
|
22 |
+
class PositionalEmbedding(nn.Module):
|
23 |
+
def __init__(self, demb):
|
24 |
+
super().__init__()
|
25 |
+
self.demb = demb
|
26 |
+
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
|
27 |
+
self.register_buffer("inv_freq", inv_freq)
|
28 |
+
|
29 |
+
def forward(self, pos_seq, bsz=None):
|
30 |
+
sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
|
31 |
+
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
|
32 |
+
if bsz is not None:
|
33 |
+
return pos_emb[None, :, :].expand(bsz, -1, -1)
|
34 |
+
else:
|
35 |
+
return pos_emb[None, :, :]
|
36 |
+
|
37 |
+
|
38 |
+
class PositionwiseConvFF(nn.Module):
|
39 |
+
def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.d_model = d_model
|
43 |
+
self.d_inner = d_inner
|
44 |
+
self.dropout = dropout
|
45 |
+
|
46 |
+
self.CoreNet = nn.Sequential(
|
47 |
+
nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
|
48 |
+
nn.ReLU(),
|
49 |
+
# nn.Dropout(dropout), # worse convergence
|
50 |
+
nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
|
51 |
+
nn.Dropout(dropout),
|
52 |
+
)
|
53 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
54 |
+
self.pre_lnorm = pre_lnorm
|
55 |
+
|
56 |
+
def forward(self, inp):
|
57 |
+
return self._forward(inp)
|
58 |
+
|
59 |
+
def _forward(self, inp):
|
60 |
+
if self.pre_lnorm:
|
61 |
+
# layer normalization + positionwise feed-forward
|
62 |
+
# core_out = inp
|
63 |
+
core_out = self.CoreNet(self.layer_norm(inp).transpose(1, 2))
|
64 |
+
core_out = core_out.transpose(1, 2)
|
65 |
+
|
66 |
+
# residual connection
|
67 |
+
output = core_out + inp
|
68 |
+
else:
|
69 |
+
# positionwise feed-forward
|
70 |
+
core_out = inp.transpose(1, 2)
|
71 |
+
core_out = self.CoreNet(core_out)
|
72 |
+
core_out = core_out.transpose(1, 2)
|
73 |
+
|
74 |
+
# residual connection + layer normalization
|
75 |
+
output = self.layer_norm(inp + core_out).to(inp.dtype)
|
76 |
+
|
77 |
+
return output
|
78 |
+
|
79 |
+
|
80 |
+
class MultiHeadAttn(nn.Module):
|
81 |
+
def __init__(
|
82 |
+
self, n_head, d_model, d_head, dropout, rel_attention, dropatt=0.1, pre_lnorm=True, rel_window_size=10
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
self.n_head = n_head
|
87 |
+
self.d_model = d_model
|
88 |
+
self.d_head = d_head
|
89 |
+
self.scale = 1 / (d_head**0.5)
|
90 |
+
self.pre_lnorm = pre_lnorm
|
91 |
+
self.rel_attention = rel_attention
|
92 |
+
if rel_attention:
|
93 |
+
self.attn = RelAttention(d_model, n_head, d_head, dropout, max_pos_emb=rel_window_size)
|
94 |
+
else:
|
95 |
+
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
|
96 |
+
self.drop = nn.Dropout(dropout)
|
97 |
+
self.dropatt = nn.Dropout(dropatt)
|
98 |
+
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
|
99 |
+
|
100 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
101 |
+
|
102 |
+
def forward(self, inp, attn_mask=None):
|
103 |
+
return self._forward(inp, attn_mask)
|
104 |
+
|
105 |
+
def _forward(self, inp, attn_mask=None):
|
106 |
+
residual = inp
|
107 |
+
|
108 |
+
if self.pre_lnorm:
|
109 |
+
# layer normalization
|
110 |
+
inp = self.layer_norm(inp)
|
111 |
+
|
112 |
+
if not self.rel_attention:
|
113 |
+
n_head, d_head = self.n_head, self.d_head
|
114 |
+
|
115 |
+
head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
|
116 |
+
head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
|
117 |
+
head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
|
118 |
+
head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
|
119 |
+
|
120 |
+
q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
|
121 |
+
k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
|
122 |
+
v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
|
123 |
+
|
124 |
+
attn_score = torch.bmm(q, k.transpose(1, 2))
|
125 |
+
attn_score.mul_(self.scale)
|
126 |
+
|
127 |
+
if attn_mask is not None:
|
128 |
+
attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
|
129 |
+
attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
|
130 |
+
                attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))

            attn_prob = F.softmax(attn_score, dim=2)
            attn_prob = self.dropatt(attn_prob)
            attn_vec = torch.bmm(attn_prob, v)

            attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
            attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head)

            # linear projection
            attn_out = self.o_net(attn_vec)
            attn_out = self.drop(attn_out)
        else:
            attn_out = self.attn(inp, mask=attn_mask)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(residual + attn_out)

        output = output.to(attn_out.dtype)

        return output


class TransformerLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs):
        super().__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm"))

    def forward(self, dec_inp, mask=None):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output


class FFTransformer(nn.Module):
    def __init__(
        self,
        n_layer,
        n_head,
        hidden_channels,
        d_head,
        d_inner,
        kernel_size,
        dropout,
        dropatt,
        dropemb=0.0,
        embed_input=False,
        n_embed=None,
        d_embed=None,
        padding_idx=0,
        pre_lnorm=True,
        rel_attention=True,
        rel_window_size=10,
    ):
        super().__init__()
        self.d_model = hidden_channels
        self.n_head = n_head
        self.d_head = d_head
        self.padding_idx = padding_idx

        if embed_input:
            self.word_emb = nn.Embedding(n_embed, d_embed or hidden_channels, padding_idx=self.padding_idx)
        else:
            self.word_emb = None

        self.rel_attention = rel_attention

        if not rel_attention:
            self.pos_emb = PositionalEmbedding(self.d_model)

        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layer):
            self.layers.append(
                TransformerLayer(
                    n_head,
                    hidden_channels,
                    d_head,
                    d_inner,
                    kernel_size,
                    dropout,
                    dropatt=dropatt,
                    pre_lnorm=pre_lnorm,
                    rel_attention=rel_attention,
                    rel_window_size=rel_window_size,
                )
            )

    def forward(self, dec_inp, mask=None, conditioning=0):
        inp = dec_inp.transpose(1, 2)
        mask = mask.bool().squeeze(1).unsqueeze(2)
        # if self.word_emb is None:
        #     inp = dec_inp
        #     mask = sequence_mask(seq_lens, inp.shape[1], device=seq_lens.device, dtype=seq_lens.dtype).unsqueeze(2)
        # else:
        #     inp = self.word_emb(dec_inp)
        #     # [bsz x L x 1]
        #     mask = (dec_inp != self.padding_idx).unsqueeze(2)

        if not self.rel_attention:
            pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
            pos_emb = self.pos_emb(pos_seq) * mask
        else:
            pos_emb = 0

        out = self.drop(inp + pos_emb + conditioning)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return rearrange(out, "b l h -> b h l")
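A minimal usage sketch of the FFTransformer above, assuming channel-first inputs of shape [batch, hidden_channels, length] with a [batch, 1, length] validity mask; the hyperparameter values and tensor sizes below are illustrative assumptions, not values taken from this commit.

# Hypothetical usage sketch for FFTransformer; sizes and hyperparameters are placeholders.
import torch

model = FFTransformer(n_layer=2, n_head=2, hidden_channels=192, d_head=96,
                      d_inner=768, kernel_size=3, dropout=0.1, dropatt=0.1)

x = torch.randn(4, 192, 100)      # [batch, hidden_channels, length]
mask = torch.ones(4, 1, 100)      # [batch, 1, length], 1 marks valid frames
out = model(x, mask=mask)         # -> [batch, hidden_channels, length]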
diff_ttsg/models/diff_ttsg.py
ADDED
@@ -0,0 +1,376 @@
import math
import random
from typing import Any

import torch
from lightning import LightningModule

import diff_ttsg.utils.monotonic_align as monotonic_align
from diff_ttsg import utils
from diff_ttsg.models.components.diffusion import Diffusion, Diffusion_Motion
from diff_ttsg.models.components.text_encoder import (MuMotionEncoder,
                                                      TextEncoder)
from diff_ttsg.utils.model import (denormalize, duration_loss,
                                   fix_len_compatibility, generate_path,
                                   sequence_mask)
from diff_ttsg.utils.utils import plot_tensor

log = utils.get_pylogger(__name__)


class Diff_TTSG(LightningModule):
    def __init__(
        self,
        n_vocab,
        n_spks,
        spk_emb_dim,
        n_enc_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_enc_layers,
        enc_kernel,
        enc_dropout,
        window_size,
        n_feats,
        n_motions,
        dec_dim,
        beta_min,
        beta_max,
        pe_scale,
        mu_motion_encoder_params,
        motion_reduction_factor,
        motion_decoder_channels,
        data_statistics,
        out_size,
        only_speech=False,
        encoder_type="default",
        optimizer=None
    ):
        super(Diff_TTSG, self).__init__()

        self.save_hyperparameters(logger=False)

        self.n_vocab = n_vocab
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.n_enc_channels = n_enc_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_enc_layers = n_enc_layers
        self.enc_kernel = enc_kernel
        self.enc_dropout = enc_dropout
        self.window_size = window_size
        self.n_feats = n_feats
        self.n_motions = n_motions
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale
        self.generate_motion = not only_speech
        self.motion_reduction_factor = motion_reduction_factor
        self.out_size = out_size
        self.mu_diffusion_channels = motion_decoder_channels

        if n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
        self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
                                   filter_channels, filter_channels_dp, n_heads,
                                   n_enc_layers, enc_kernel, enc_dropout, window_size, encoder_type=encoder_type)
        self.decoder = Diffusion(n_feats, dec_dim, n_spks, spk_emb_dim, beta_min, beta_max, pe_scale)

        if self.generate_motion:
            self.motion_prior_loss = mu_motion_encoder_params.pop('prior_loss', True)
            self.mu_motion_encoder = MuMotionEncoder(
                input_channels=n_feats,
                output_channels=n_motions,
                **mu_motion_encoder_params
            )
            self.decoder_motion = Diffusion_Motion(
                in_channels=n_motions,
                motion_decoder_channels=motion_decoder_channels,
                beta_min=beta_min,
                beta_max=beta_max,
            )

        self.update_data_statistics(data_statistics)

    def update_data_statistics(self, data_statistics):
        if data_statistics is None:
            data_statistics = {
                'mel_mean': 0.0,
                'mel_std': 1.0,
                'motion_mean': 0.0,
                'motion_std': 1.0,
            }

        self.register_buffer('mel_mean', torch.tensor(data_statistics['mel_mean']))
        self.register_buffer('mel_std', torch.tensor(data_statistics['mel_std']))
        self.register_buffer('motion_mean', torch.tensor(data_statistics['motion_mean']))
        self.register_buffer('motion_std', torch.tensor(data_statistics['motion_std']))

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns:
        1. encoder outputs
        2. decoder outputs
        3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
            x_lengths (torch.Tensor): lengths of texts in batch.
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
                Usually, does not provide synthesis improvements.
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.
        """
        if isinstance(n_timesteps, dict):
            n_timestep_mel = n_timesteps['mel']
            n_timestep_motion = n_timesteps['motion']
        else:
            n_timestep_mel = n_timesteps
            n_timestep_motion = n_timesteps

        if isinstance(temperature, dict):
            temperature_mel = temperature['mel']
            temperature_motion = temperature['motion']
        else:
            temperature_mel = temperature
            temperature_motion = temperature

        if self.n_spks > 1:
            # Get speaker embedding
            spk = self.spk_emb(spk)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = int(y_lengths.max())
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature_mel
        # Generate sample by performing reverse dynamics
        decoder_outputs = self.decoder(z, y_mask, mu_y, n_timestep_mel, stoc, spk)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        if self.generate_motion:
            mu_y_motion = mu_y[:, :, ::self.motion_reduction_factor]
            y_motion_mask = y_mask[:, :, ::self.motion_reduction_factor]
            mu_y_motion = self.mu_motion_encoder(mu_y_motion, y_motion_mask)
            encoder_outputs_motion = mu_y_motion[:, :, :y_max_length]
            # Sample latent representation from terminal distribution N(mu_y_motion, I)
            z_motion = mu_y_motion + torch.randn_like(mu_y_motion, device=mu_y_motion.device) / temperature_motion
            # Generate sample by performing reverse dynamics
            decoder_outputs_motion = self.decoder_motion(z_motion, y_motion_mask, mu_y_motion, n_timestep_motion, stoc, spk)
            decoder_outputs_motion = decoder_outputs_motion[:, :, :y_max_length]
        else:
            decoder_outputs_motion = None
            encoder_outputs_motion = None

        return {
            'encoder_outputs_mel': encoder_outputs,
            'decoder_outputs_mel': decoder_outputs,
            'encoder_outputs_motion': encoder_outputs_motion,
            'decoder_outputs_motion': decoder_outputs_motion,
            'attn': attn[:, :, :y_max_length],
            'mel': denormalize(decoder_outputs, self.mel_mean, self.mel_std),
            'motion': denormalize(decoder_outputs_motion, self.motion_mean, self.motion_std) if self.generate_motion else None,
        }

    def forward(self, x, x_lengths, y, y_lengths, y_motion, spk=None, out_size=None):
        """
        Computes 3 losses:
        1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
        2. prior loss: loss between mel-spectrogram and encoder outputs.
        3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
            x_lengths (torch.Tensor): lengths of texts in batch.
            y (torch.Tensor): batch of corresponding mel-spectrograms.
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
            out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
                Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
        """
        if self.n_spks > 1:
            # Get speaker embedding
            spk = self.spk_emb(spk)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
        y_max_length = y.shape[-1]

        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
        with torch.no_grad():
            const = -0.5 * math.log(2 * math.pi) * self.n_feats
            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
            log_prior = y_square - y_mu_double + mu_square + const

            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
            attn = attn.detach()

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Cut a small segment of mel-spectrogram in order to increase batch size
        if not isinstance(out_size, type(None)):
            max_offset = (y_lengths - out_size).clamp(0)  # e.g. max_offset: [758, 160, 773]
            offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))  # offset ranges for each sample in batch, e.g. [(0, 758), (0, 160), (0, 773)]
            out_offset = torch.LongTensor([
                torch.tensor(random.choice(range(start, end)) if end > start else 0)
                for start, end in offset_ranges
            ]).to(y_lengths)
            attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
            y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)

            if self.generate_motion:
                y_motion_cut = torch.zeros(y_motion.shape[0], self.n_motions, out_size, dtype=y_motion.dtype, device=y_motion.device)

            y_cut_lengths = []
            for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
                y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
                y_cut_lengths.append(y_cut_length)
                cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
                y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
                if self.generate_motion:
                    y_motion_cut[i, :, :y_cut_length] = y_motion[i, :, cut_lower:cut_upper]

                attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
            y_cut_lengths = torch.LongTensor(y_cut_lengths)
            y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)

            attn = attn_cut
            y = y_cut
            if self.generate_motion:
                y_motion = y_motion_cut

            y_mask = y_cut_mask

        # Align encoded text with mel-spectrogram and get mu_y segment
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of score-based decoder
        diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk)
        if self.generate_motion:
            # Reduce motion features
            mu_y_motion = mu_y[:, :, ::self.motion_reduction_factor]
            y_motion_mask = y_mask[:, :, ::self.motion_reduction_factor]
            y_motion = y_motion[:, :, ::self.motion_reduction_factor]

            mu_y_motion = self.mu_motion_encoder(mu_y_motion, y_motion_mask)
            diff_loss_motion, xt_motion = self.decoder_motion.compute_loss(y_motion, y_motion_mask, mu_y_motion, spk)
        else:
            diff_loss_motion = 0

        # Compute loss between aligned encoder outputs and mel-spectrogram
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)

        if self.generate_motion and self.motion_prior_loss:
            prior_loss_motion = torch.sum(0.5 * ((y_motion - mu_y_motion) ** 2 + math.log(2 * math.pi)) * y_motion_mask)
            prior_loss_motion = prior_loss_motion / (torch.sum(y_motion_mask) * self.n_motions)
        else:
            prior_loss_motion = 0

        return dur_loss, prior_loss + prior_loss_motion, diff_loss + diff_loss_motion

    def configure_optimizers(self) -> Any:
        optimizer = self.hparams.optimizer(params=self.parameters())
        return {'optimizer': optimizer}

    def get_losses(self, batch):
        x, x_lengths = batch['x'], batch['x_lengths']
        y, y_lengths = batch['y'], batch['y_lengths']
        y_motion = batch['y_motion']
        dur_loss, prior_loss, diff_loss = self(x, x_lengths, y, y_lengths, y_motion, out_size=self.out_size)
        return {
            'dur_loss': dur_loss,
            'prior_loss': prior_loss,
            'diff_loss': diff_loss,
        }

    def training_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log('step', float(self.global_step), on_step=True, on_epoch=True, logger=True, sync_dist=True)

        self.log('sub_loss/train_dur_loss', loss_dict['dur_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('sub_loss/train_prior_loss', loss_dict['prior_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('sub_loss/train_diff_loss', loss_dict['diff_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)

        total_loss = sum(loss_dict.values())
        self.log('loss/train', total_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)

        return {'loss': total_loss, 'log': loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log('sub_loss/val_dur_loss', loss_dict['dur_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('sub_loss/val_prior_loss', loss_dict['prior_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('sub_loss/val_diff_loss', loss_dict['diff_loss'], on_step=True, on_epoch=True, logger=True, sync_dist=True)

        total_loss = sum(loss_dict.values())
        self.log('loss/val', total_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)

        return total_loss

    def on_validation_end(self) -> None:
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))
            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(4):
                    y = one_batch['y'][i].unsqueeze(0).to(self.device)
                    y_motion = one_batch['y_motion'][i].unsqueeze(0).to(self.device)
                    self.logger.experiment.add_image(f'original/mel_{i}', plot_tensor(y.squeeze().cpu()), self.current_epoch, dataformats='HWC')
                    if self.generate_motion:
                        self.logger.experiment.add_image(f'original/motion_{i}', plot_tensor(y_motion.squeeze().cpu()), self.current_epoch, dataformats='HWC')

            log.debug("Synthesising...")
            for i in range(4):
                x = one_batch['x'][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch['x_lengths'][i].unsqueeze(0).to(self.device)
                output = self.synthesise(x, x_lengths, n_timesteps=20)
                y_enc, y_dec = output['encoder_outputs_mel'], output['decoder_outputs_mel']
                y_motion_enc, y_motion_dec, attn = output['encoder_outputs_motion'], output['decoder_outputs_motion'], output['attn']
                self.logger.experiment.add_image(f'generated_enc/{i}', plot_tensor(y_enc.squeeze().cpu()), self.current_epoch, dataformats='HWC')
                self.logger.experiment.add_image(f'generated_dec/{i}', plot_tensor(y_dec.squeeze().cpu()), self.current_epoch, dataformats='HWC')
                if self.generate_motion:
                    self.logger.experiment.add_image(f'generated_enc_motion/{i}', plot_tensor(y_motion_enc.squeeze().cpu()), self.current_epoch, dataformats='HWC')
                    self.logger.experiment.add_image(f'generated_dec_motion/{i}', plot_tensor(y_motion_dec.squeeze().cpu()), self.current_epoch, dataformats='HWC')

                self.logger.experiment.add_image(f'alignment/{i}', plot_tensor(attn.squeeze().cpu()), self.current_epoch, dataformats='HWC')
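A minimal sketch of how the `synthesise` method above could be driven at inference time, assuming a trained Lightning checkpoint and the `text_to_sequence` front end added later in this commit; the checkpoint path, input sentence, step counts and temperatures are illustrative assumptions, not values from this commit.

# Hypothetical inference sketch (paths, sentence and step counts are assumptions).
import torch
from diff_ttsg.models.diff_ttsg import Diff_TTSG
from diff_ttsg.text import text_to_sequence

model = Diff_TTSG.load_from_checkpoint("checkpoints/last.ckpt").eval()

seq = text_to_sequence("Hello world.")            # list of symbol ids
x = torch.tensor(seq, dtype=torch.long)[None]     # [1, T_text]
x_lengths = torch.tensor([x.shape[-1]])

out = model.synthesise(
    x, x_lengths,
    n_timesteps={"mel": 50, "motion": 50},        # per-stream reverse-diffusion steps
    temperature={"mel": 1.3, "motion": 1.3},
    length_scale=1.0,
)
mel = out["mel"]        # de-normalised mel-spectrogram, [1, n_feats, T_mel]
motion = out["motion"]  # de-normalised motion features, or None in speech-only mode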
diff_ttsg/models/mnist_module.py
ADDED
@@ -0,0 +1,137 @@
from typing import Any

import torch
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy


class MNISTLitModule(LightningModule):
    """Example of LightningModule for MNIST classification.

    A LightningModule organizes your PyTorch code into 6 sections:
        - Initialization (__init__)
        - Train Loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Prediction Loop (predict_step)
        - Optimizers and LR Schedulers (configure_optimizers)

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # metric objects for calculating and averaging accuracy across batches
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor):
        return self.net(x)

    def on_train_start(self):
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_acc.reset()
        self.val_acc_best.reset()

    def model_step(self, batch: Any):
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.train_loss(loss)
        self.train_acc(preds, targets)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.val_acc(preds, targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self):
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.test_acc(preds, targets)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self):
        pass

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


if __name__ == "__main__":
    _ = MNISTLitModule(None, None, None)
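Because configure_optimizers calls self.hparams.optimizer(params=...) and self.hparams.scheduler(optimizer=...), the module expects partially-applied constructors rather than ready-made optimizer objects (the Hydra-instantiation pattern of the template this file comes from). A minimal wiring sketch, with the network and hyperparameters as made-up placeholders:

# Hypothetical wiring sketch; the network and hyperparameters are placeholders.
from functools import partial

import torch

net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

module = MNISTLitModule(
    net=net,
    optimizer=partial(torch.optim.Adam, lr=1e-3),                   # later called with params=...
    scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau),  # later called with optimizer=...
)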
diff_ttsg/resources/cmu_dictionary
ADDED
The diff for this file is too large to render.
See raw diff
diff_ttsg/text/LICENSE
ADDED
@@ -0,0 +1,30 @@
CMUdict
-------

CMUdict (the Carnegie Mellon Pronouncing Dictionary) is a free
pronouncing dictionary of English, suitable for uses in speech
technology and is maintained by the Speech Group in the School of
Computer Science at Carnegie Mellon University.

The Carnegie Mellon Speech Group does not guarantee the accuracy of
this dictionary, nor its suitability for any specific purpose. In
fact, we expect a number of errors, omissions and inconsistencies to
remain in the dictionary. We intend to continually update the
dictionary by correction existing entries and by adding new ones. From
time to time a new major version will be released.

We welcome input from users: Please send email to Alex Rudnicky
([email protected]).

The Carnegie Mellon Pronouncing Dictionary, in its current and
previous versions is Copyright (C) 1993-2014 by Carnegie Mellon
University. Use of this dictionary for any research or commercial
purpose is completely unrestricted. If you make use of or
redistribute this material we request that you acknowledge its
origin in your descriptions.

If you add words to or correct words in your version of this
dictionary, we would appreciate it if you could send these additions
and corrections to us ([email protected]) for consideration in a
subsequent version. All submissions will be reviewed and approved by
the current maintainer, Alex Rudnicky at Carnegie Mellon.
diff_ttsg/text/__init__.py
ADDED
@@ -0,0 +1,96 @@
""" from https://github.com/keithito/tacotron """

import re

from diff_ttsg.text import cleaners
from diff_ttsg.text.symbols import symbols

_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')


def get_arpabet(word, dictionary):
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word


def text_to_sequence(text, cleaner_names=["english_cleaners"], dictionary=None):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
        dictionary: arpabet class with arpabet dictionary

    Returns:
        List of integers corresponding to the symbols in the text
    '''
    sequence = []
    space = _symbols_to_sequence(' ')
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence


def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string'''
    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]
            result += s
    return result.replace('}{', ' ')


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(['@' + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != '_' and s != '~'
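A short usage sketch of this front end, assuming the bundled pronouncing dictionary at diff_ttsg/resources/cmu_dictionary is used for ARPAbet lookups; the input sentence is a made-up example.

# Hypothetical usage sketch for the text front end.
from diff_ttsg.text import text_to_sequence, sequence_to_text
from diff_ttsg.text.cmudict import CMUDict

cmu = CMUDict("diff_ttsg/resources/cmu_dictionary")

ids = text_to_sequence("Dr. Smith paid $5 for the tickets.", dictionary=cmu)
print(ids)                    # list of integer symbol ids
print(sequence_to_text(ids))  # round-trips to cleaned text with {ARPAbet} words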
diff_ttsg/text/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.41 kB).
diff_ttsg/text/__pycache__/cleaners.cpython-310.pyc
ADDED
Binary file (1.98 kB).
diff_ttsg/text/__pycache__/cmudict.cpython-310.pyc
ADDED
Binary file (2.22 kB).
diff_ttsg/text/__pycache__/numbers.cpython-310.pyc
ADDED
Binary file (2.22 kB).
diff_ttsg/text/__pycache__/symbols.cpython-310.pyc
ADDED
Binary file (604 Bytes).
diff_ttsg/text/cleaners.py
ADDED
@@ -0,0 +1,73 @@
""" from https://github.com/keithito/tacotron """

import re
from unidecode import unidecode
from .numbers import normalize_numbers


_whitespace_re = re.compile(r'\s+')

_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
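For illustration, english_cleaners simply composes the helpers above (ASCII transliteration, lowercasing, number expansion, abbreviation expansion, whitespace collapsing); a quick sketch on a made-up sentence, with the expected output shown approximately:

# Hypothetical example of the cleaning pipeline (output shown approximately).
from diff_ttsg.text.cleaners import english_cleaners

print(english_cleaners("Mr. Müller bought 2 books for $3.50."))
# roughly: "mister muller bought two books for three dollars, fifty cents."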
diff_ttsg/text/cmudict.py
ADDED
@@ -0,0 +1,60 @@
""" from https://github.com/keithito/tacotron """

import re


valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding='latin-1') as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        return self._entries.get(word.upper())


_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
            # CMUdict separates the word from its pronunciation with two spaces
            parts = line.split('  ')
            word = re.sub(_alt_re, '', parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(' ')
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return ' '.join(parts)
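A small sketch of the dictionary wrapper in use; the looked-up word and the shown pronunciation are arbitrary examples.

# Hypothetical lookup example.
from diff_ttsg.text.cmudict import CMUDict

cmu = CMUDict("diff_ttsg/resources/cmu_dictionary")
print(len(cmu))             # number of parsed entries
print(cmu.lookup("hello"))  # e.g. ['HH AH0 L OW1'], or None if the word is missing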
diff_ttsg/text/numbers.py
ADDED
@@ -0,0 +1,72 @@
""" from https://github.com/keithito/tacotron """

import inflect
import re


_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')


def _remove_commas(m):
    return m.group(1).replace(',', '')


def _expand_decimal_point(m):
    return m.group(1).replace('.', ' point ')


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split('.')
    if len(parts) > 2:
        return match + ' dollars'
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        return '%s %s' % (dollars, dollar_unit)
    elif cents:
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s' % (cents, cent_unit)
    else:
        return 'zero dollars'


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return 'two thousand'
        elif num > 2000 and num < 2010:
            return 'two thousand ' + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + ' hundred'
        else:
            return _inflect.number_to_words(num, andword='', zero='oh',
                                            group=2).replace(', ', ' ')
    else:
        return _inflect.number_to_words(num, andword='')


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r'\1 pounds', text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
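Finally, a tiny sketch of normalize_numbers on its own; the example strings are made up and the outputs are shown approximately.

# Hypothetical examples of number normalisation (outputs shown approximately).
from diff_ttsg.text.numbers import normalize_numbers

print(normalize_numbers("In 1975 it cost $1,000.50."))
# roughly: "In nineteen seventy-five it cost one thousand dollars, fifty cents."
print(normalize_numbers("She finished 2nd out of 15."))
# roughly: "She finished second out of fifteen."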