Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- .github/workflows/pre-commit.yaml +14 -0
- .pre-commit-config.yaml +14 -0
- README_REPO.md +20 -0
- app.py +109 -73
- finetune-cli.py +61 -42
- finetune_gradio.py +424 -358
- inference-cli.py +21 -30
- model/__init__.py +3 -0
- model/backbones/dit.py +47 -40
- model/backbones/mmdit.py +37 -25
- model/backbones/unett.py +66 -46
- model/cfm.py +67 -59
- model/dataset.py +98 -75
- model/ecapa_tdnn.py +97 -35
- model/modules.py +113 -107
- model/trainer.py +123 -82
- model/utils.py +178 -136
- model/utils_infer.py +33 -34
- ruff.toml +10 -0
- scripts/count_max_epoch.py +4 -3
- scripts/count_params_gflops.py +10 -6
- scripts/eval_infer_batch.py +59 -59
- scripts/eval_librispeech_test_clean.py +6 -4
- scripts/eval_seedtts_testset.py +10 -8
- scripts/prepare_csv_wavs.py +16 -10
- scripts/prepare_emilia.py +100 -16
- scripts/prepare_wenetspeech4tts.py +15 -12
- speech_edit.py +45 -39
- train.py +33 -35
.github/workflows/pre-commit.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pre-commit
|
2 |
+
|
3 |
+
on:
|
4 |
+
pull_request:
|
5 |
+
push:
|
6 |
+
branches: [main]
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
pre-commit:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
steps:
|
12 |
+
- uses: actions/checkout@v3
|
13 |
+
- uses: actions/setup-python@v3
|
14 |
+
- uses: pre-commit/[email protected]
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
3 |
+
# Ruff version.
|
4 |
+
rev: v0.7.0
|
5 |
+
hooks:
|
6 |
+
# Run the linter.
|
7 |
+
- id: ruff
|
8 |
+
args: [--fix]
|
9 |
+
# Run the formatter.
|
10 |
+
- id: ruff-format
|
11 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
12 |
+
rev: v2.3.0
|
13 |
+
hooks:
|
14 |
+
- id: check-yaml
|
README_REPO.md
CHANGED
@@ -43,6 +43,26 @@ pip install -r requirements.txt
|
|
43 |
docker build -t f5tts:v1 .
|
44 |
```
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
## Prepare Dataset
|
47 |
|
48 |
Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
|
|
|
43 |
docker build -t f5tts:v1 .
|
44 |
```
|
45 |
|
46 |
+
### Development
|
47 |
+
|
48 |
+
When making a pull request, please use pre-commit to ensure code quality:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
pip install pre-commit
|
52 |
+
pre-commit install
|
53 |
+
```
|
54 |
+
|
55 |
+
This will run linters and formatters automatically before each commit.
|
56 |
+
|
57 |
+
Manually run using:
|
58 |
+
|
59 |
+
```bash
|
60 |
+
pre-commit run --all-files
|
61 |
+
```
|
62 |
+
|
63 |
+
Note: Some model components have linting exceptions for E722 to accommodate tensor notation
|
64 |
+
|
65 |
+
|
66 |
## Prepare Dataset
|
67 |
|
68 |
Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
|
app.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import re
|
2 |
import tempfile
|
3 |
|
@@ -11,16 +14,19 @@ from pydub import AudioSegment
|
|
11 |
|
12 |
try:
|
13 |
import spaces
|
|
|
14 |
USING_SPACES = True
|
15 |
except ImportError:
|
16 |
USING_SPACES = False
|
17 |
|
|
|
18 |
def gpu_decorator(func):
|
19 |
if USING_SPACES:
|
20 |
return spaces.GPU(func)
|
21 |
else:
|
22 |
return func
|
23 |
|
|
|
24 |
from model import DiT, UNetT
|
25 |
from model.utils import (
|
26 |
save_spectrogram,
|
@@ -38,15 +44,18 @@ vocos = load_vocoder()
|
|
38 |
|
39 |
# load models
|
40 |
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
41 |
-
F5TTS_ema_model = load_model(
|
|
|
|
|
42 |
|
43 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
44 |
-
E2TTS_ema_model = load_model(
|
|
|
|
|
45 |
|
46 |
|
47 |
@gpu_decorator
|
48 |
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
|
49 |
-
|
50 |
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
|
51 |
|
52 |
if model == "F5-TTS":
|
@@ -54,7 +63,16 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
|
|
54 |
elif model == "E2-TTS":
|
55 |
ema_model = E2TTS_ema_model
|
56 |
|
57 |
-
final_wave, final_sample_rate, combined_spectrogram = infer_process(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
# Remove silence
|
60 |
if remove_silence:
|
@@ -73,17 +91,19 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
|
|
73 |
|
74 |
|
75 |
@gpu_decorator
|
76 |
-
def generate_podcast(
|
|
|
|
|
77 |
# Split the script into speaker blocks
|
78 |
speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
|
79 |
speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
|
80 |
-
|
81 |
generated_audio_segments = []
|
82 |
-
|
83 |
for i in range(0, len(speaker_blocks), 2):
|
84 |
speaker = speaker_blocks[i]
|
85 |
-
text = speaker_blocks[i+1].strip()
|
86 |
-
|
87 |
# Determine which speaker is talking
|
88 |
if speaker == speaker1_name:
|
89 |
ref_audio = ref_audio1
|
@@ -93,51 +113,52 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name
|
|
93 |
ref_text = ref_text2
|
94 |
else:
|
95 |
continue # Skip if the speaker is neither speaker1 nor speaker2
|
96 |
-
|
97 |
# Generate audio for this block
|
98 |
audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
|
99 |
-
|
100 |
# Convert the generated audio to a numpy array
|
101 |
sr, audio_data = audio
|
102 |
-
|
103 |
# Save the audio data as a WAV file
|
104 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
105 |
sf.write(temp_file.name, audio_data, sr)
|
106 |
audio_segment = AudioSegment.from_wav(temp_file.name)
|
107 |
-
|
108 |
generated_audio_segments.append(audio_segment)
|
109 |
-
|
110 |
# Add a short pause between speakers
|
111 |
pause = AudioSegment.silent(duration=500) # 500ms pause
|
112 |
generated_audio_segments.append(pause)
|
113 |
-
|
114 |
# Concatenate all audio segments
|
115 |
final_podcast = sum(generated_audio_segments)
|
116 |
-
|
117 |
# Export the final podcast
|
118 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
119 |
podcast_path = temp_file.name
|
120 |
final_podcast.export(podcast_path, format="wav")
|
121 |
-
|
122 |
return podcast_path
|
123 |
|
|
|
124 |
def parse_speechtypes_text(gen_text):
|
125 |
# Pattern to find (Emotion)
|
126 |
-
pattern = r
|
127 |
|
128 |
# Split the text by the pattern
|
129 |
tokens = re.split(pattern, gen_text)
|
130 |
|
131 |
segments = []
|
132 |
|
133 |
-
current_emotion =
|
134 |
|
135 |
for i in range(len(tokens)):
|
136 |
if i % 2 == 0:
|
137 |
# This is text
|
138 |
text = tokens[i].strip()
|
139 |
if text:
|
140 |
-
segments.append({
|
141 |
else:
|
142 |
# This is emotion
|
143 |
emotion = tokens[i].strip()
|
@@ -158,9 +179,7 @@ with gr.Blocks() as app_tts:
|
|
158 |
gr.Markdown("# Batched TTS")
|
159 |
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
160 |
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
|
161 |
-
model_choice = gr.Radio(
|
162 |
-
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
|
163 |
-
)
|
164 |
generate_btn = gr.Button("Synthesize", variant="primary")
|
165 |
with gr.Accordion("Advanced Settings", open=False):
|
166 |
ref_text_input = gr.Textbox(
|
@@ -206,23 +225,24 @@ with gr.Blocks() as app_tts:
|
|
206 |
],
|
207 |
outputs=[audio_output, spectrogram_output],
|
208 |
)
|
209 |
-
|
210 |
with gr.Blocks() as app_podcast:
|
211 |
gr.Markdown("# Podcast Generation")
|
212 |
speaker1_name = gr.Textbox(label="Speaker 1 Name")
|
213 |
ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
|
214 |
ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
|
215 |
-
|
216 |
speaker2_name = gr.Textbox(label="Speaker 2 Name")
|
217 |
ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
|
218 |
ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
|
219 |
-
|
220 |
-
script_input = gr.Textbox(
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
|
225 |
)
|
|
|
|
|
226 |
podcast_remove_silence = gr.Checkbox(
|
227 |
label="Remove Silences",
|
228 |
value=True,
|
@@ -230,8 +250,12 @@ with gr.Blocks() as app_podcast:
|
|
230 |
generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
|
231 |
podcast_output = gr.Audio(label="Generated Podcast")
|
232 |
|
233 |
-
def podcast_generation(
|
234 |
-
|
|
|
|
|
|
|
|
|
235 |
|
236 |
generate_podcast_btn.click(
|
237 |
podcast_generation,
|
@@ -249,23 +273,24 @@ with gr.Blocks() as app_podcast:
|
|
249 |
outputs=podcast_output,
|
250 |
)
|
251 |
|
|
|
252 |
def parse_emotional_text(gen_text):
|
253 |
# Pattern to find (Emotion)
|
254 |
-
pattern = r
|
255 |
|
256 |
# Split the text by the pattern
|
257 |
tokens = re.split(pattern, gen_text)
|
258 |
|
259 |
segments = []
|
260 |
|
261 |
-
current_emotion =
|
262 |
|
263 |
for i in range(len(tokens)):
|
264 |
if i % 2 == 0:
|
265 |
# This is text
|
266 |
text = tokens[i].strip()
|
267 |
if text:
|
268 |
-
segments.append({
|
269 |
else:
|
270 |
# This is emotion
|
271 |
emotion = tokens[i].strip()
|
@@ -273,6 +298,7 @@ def parse_emotional_text(gen_text):
|
|
273 |
|
274 |
return segments
|
275 |
|
|
|
276 |
with gr.Blocks() as app_emotional:
|
277 |
# New section for emotional generation
|
278 |
gr.Markdown(
|
@@ -287,13 +313,15 @@ with gr.Blocks() as app_emotional:
|
|
287 |
"""
|
288 |
)
|
289 |
|
290 |
-
gr.Markdown(
|
|
|
|
|
291 |
|
292 |
# Regular speech type (mandatory)
|
293 |
with gr.Row():
|
294 |
-
regular_name = gr.Textbox(value=
|
295 |
-
regular_audio = gr.Audio(label=
|
296 |
-
regular_ref_text = gr.Textbox(label=
|
297 |
|
298 |
# Additional speech types (up to 99 more)
|
299 |
max_speech_types = 100
|
@@ -304,9 +332,9 @@ with gr.Blocks() as app_emotional:
|
|
304 |
|
305 |
for i in range(max_speech_types - 1):
|
306 |
with gr.Row():
|
307 |
-
name_input = gr.Textbox(label=
|
308 |
-
audio_input = gr.Audio(label=
|
309 |
-
ref_text_input = gr.Textbox(label=
|
310 |
delete_btn = gr.Button("Delete", variant="secondary", visible=False)
|
311 |
speech_type_names.append(name_input)
|
312 |
speech_type_audios.append(audio_input)
|
@@ -351,7 +379,11 @@ with gr.Blocks() as app_emotional:
|
|
351 |
add_speech_type_btn.click(
|
352 |
add_speech_type_fn,
|
353 |
inputs=speech_type_count,
|
354 |
-
outputs=[speech_type_count]
|
|
|
|
|
|
|
|
|
355 |
)
|
356 |
|
357 |
# Function to delete a speech type
|
@@ -365,9 +397,9 @@ with gr.Blocks() as app_emotional:
|
|
365 |
|
366 |
for i in range(max_speech_types - 1):
|
367 |
if i == index:
|
368 |
-
name_updates.append(gr.update(visible=False, value=
|
369 |
audio_updates.append(gr.update(visible=False, value=None))
|
370 |
-
ref_text_updates.append(gr.update(visible=False, value=
|
371 |
delete_btn_updates.append(gr.update(visible=False))
|
372 |
else:
|
373 |
name_updates.append(gr.update())
|
@@ -386,16 +418,18 @@ with gr.Blocks() as app_emotional:
|
|
386 |
delete_btn.click(
|
387 |
delete_fn,
|
388 |
inputs=speech_type_count,
|
389 |
-
outputs=[speech_type_count]
|
|
|
|
|
|
|
|
|
390 |
)
|
391 |
|
392 |
# Text input for the prompt
|
393 |
gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
|
394 |
|
395 |
# Model choice
|
396 |
-
model_choice_emotional = gr.Radio(
|
397 |
-
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
|
398 |
-
)
|
399 |
|
400 |
with gr.Accordion("Advanced Settings", open=False):
|
401 |
remove_silence_emotional = gr.Checkbox(
|
@@ -408,6 +442,7 @@ with gr.Blocks() as app_emotional:
|
|
408 |
|
409 |
# Output audio
|
410 |
audio_output_emotional = gr.Audio(label="Synthesized Audio")
|
|
|
411 |
@gpu_decorator
|
412 |
def generate_emotional_speech(
|
413 |
regular_audio,
|
@@ -417,37 +452,39 @@ with gr.Blocks() as app_emotional:
|
|
417 |
):
|
418 |
num_additional_speech_types = max_speech_types - 1
|
419 |
speech_type_names_list = args[:num_additional_speech_types]
|
420 |
-
speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
|
421 |
-
speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
|
422 |
model_choice = args[3 * num_additional_speech_types]
|
423 |
remove_silence = args[3 * num_additional_speech_types + 1]
|
424 |
|
425 |
# Collect the speech types and their audios into a dict
|
426 |
-
speech_types = {
|
427 |
|
428 |
-
for name_input, audio_input, ref_text_input in zip(
|
|
|
|
|
429 |
if name_input and audio_input:
|
430 |
-
speech_types[name_input] = {
|
431 |
|
432 |
# Parse the gen_text into segments
|
433 |
segments = parse_speechtypes_text(gen_text)
|
434 |
|
435 |
# For each segment, generate speech
|
436 |
generated_audio_segments = []
|
437 |
-
current_emotion =
|
438 |
|
439 |
for segment in segments:
|
440 |
-
emotion = segment[
|
441 |
-
text = segment[
|
442 |
|
443 |
if emotion in speech_types:
|
444 |
current_emotion = emotion
|
445 |
else:
|
446 |
# If emotion not available, default to Regular
|
447 |
-
current_emotion =
|
448 |
|
449 |
-
ref_audio = speech_types[current_emotion][
|
450 |
-
ref_text = speech_types[current_emotion].get(
|
451 |
|
452 |
# Generate speech for this segment
|
453 |
audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
|
@@ -469,7 +506,11 @@ with gr.Blocks() as app_emotional:
|
|
469 |
regular_audio,
|
470 |
regular_ref_text,
|
471 |
gen_text_input_emotional,
|
472 |
-
]
|
|
|
|
|
|
|
|
|
473 |
model_choice_emotional,
|
474 |
remove_silence_emotional,
|
475 |
],
|
@@ -477,11 +518,7 @@ with gr.Blocks() as app_emotional:
|
|
477 |
)
|
478 |
|
479 |
# Validation function to disable Generate button if speech types are missing
|
480 |
-
def validate_speech_types(
|
481 |
-
gen_text,
|
482 |
-
regular_name,
|
483 |
-
*args
|
484 |
-
):
|
485 |
num_additional_speech_types = max_speech_types - 1
|
486 |
speech_type_names_list = args[:num_additional_speech_types]
|
487 |
|
@@ -495,7 +532,7 @@ with gr.Blocks() as app_emotional:
|
|
495 |
|
496 |
# Parse the gen_text to get the speech types used
|
497 |
segments = parse_emotional_text(gen_text)
|
498 |
-
speech_types_in_text = set(segment[
|
499 |
|
500 |
# Check if all speech types in text are available
|
501 |
missing_speech_types = speech_types_in_text - speech_types_available
|
@@ -510,7 +547,7 @@ with gr.Blocks() as app_emotional:
|
|
510 |
gen_text_input_emotional.change(
|
511 |
validate_speech_types,
|
512 |
inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
|
513 |
-
outputs=generate_emotional_btn
|
514 |
)
|
515 |
with gr.Blocks() as app:
|
516 |
gr.Markdown(
|
@@ -531,6 +568,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
|
531 |
)
|
532 |
gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
|
533 |
|
|
|
534 |
@click.command()
|
535 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
536 |
@click.option("--host", "-H", default=None, help="Host to run the app on")
|
@@ -544,10 +582,8 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
|
544 |
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
|
545 |
def main(port, host, share, api):
|
546 |
global app
|
547 |
-
print(
|
548 |
-
app.queue(api_open=api).launch(
|
549 |
-
server_name=host, server_port=port, share=share, show_api=api
|
550 |
-
)
|
551 |
|
552 |
|
553 |
if __name__ == "__main__":
|
|
|
1 |
+
# ruff: noqa: E402
|
2 |
+
# Above allows ruff to ignore E402: module level import not at top of file
|
3 |
+
|
4 |
import re
|
5 |
import tempfile
|
6 |
|
|
|
14 |
|
15 |
try:
|
16 |
import spaces
|
17 |
+
|
18 |
USING_SPACES = True
|
19 |
except ImportError:
|
20 |
USING_SPACES = False
|
21 |
|
22 |
+
|
23 |
def gpu_decorator(func):
|
24 |
if USING_SPACES:
|
25 |
return spaces.GPU(func)
|
26 |
else:
|
27 |
return func
|
28 |
|
29 |
+
|
30 |
from model import DiT, UNetT
|
31 |
from model.utils import (
|
32 |
save_spectrogram,
|
|
|
44 |
|
45 |
# load models
|
46 |
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
47 |
+
F5TTS_ema_model = load_model(
|
48 |
+
DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
|
49 |
+
)
|
50 |
|
51 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
52 |
+
E2TTS_ema_model = load_model(
|
53 |
+
UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
54 |
+
)
|
55 |
|
56 |
|
57 |
@gpu_decorator
|
58 |
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
|
|
|
59 |
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
|
60 |
|
61 |
if model == "F5-TTS":
|
|
|
63 |
elif model == "E2-TTS":
|
64 |
ema_model = E2TTS_ema_model
|
65 |
|
66 |
+
final_wave, final_sample_rate, combined_spectrogram = infer_process(
|
67 |
+
ref_audio,
|
68 |
+
ref_text,
|
69 |
+
gen_text,
|
70 |
+
ema_model,
|
71 |
+
cross_fade_duration=cross_fade_duration,
|
72 |
+
speed=speed,
|
73 |
+
show_info=gr.Info,
|
74 |
+
progress=gr.Progress(),
|
75 |
+
)
|
76 |
|
77 |
# Remove silence
|
78 |
if remove_silence:
|
|
|
91 |
|
92 |
|
93 |
@gpu_decorator
|
94 |
+
def generate_podcast(
|
95 |
+
script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence
|
96 |
+
):
|
97 |
# Split the script into speaker blocks
|
98 |
speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
|
99 |
speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
|
100 |
+
|
101 |
generated_audio_segments = []
|
102 |
+
|
103 |
for i in range(0, len(speaker_blocks), 2):
|
104 |
speaker = speaker_blocks[i]
|
105 |
+
text = speaker_blocks[i + 1].strip()
|
106 |
+
|
107 |
# Determine which speaker is talking
|
108 |
if speaker == speaker1_name:
|
109 |
ref_audio = ref_audio1
|
|
|
113 |
ref_text = ref_text2
|
114 |
else:
|
115 |
continue # Skip if the speaker is neither speaker1 nor speaker2
|
116 |
+
|
117 |
# Generate audio for this block
|
118 |
audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
|
119 |
+
|
120 |
# Convert the generated audio to a numpy array
|
121 |
sr, audio_data = audio
|
122 |
+
|
123 |
# Save the audio data as a WAV file
|
124 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
125 |
sf.write(temp_file.name, audio_data, sr)
|
126 |
audio_segment = AudioSegment.from_wav(temp_file.name)
|
127 |
+
|
128 |
generated_audio_segments.append(audio_segment)
|
129 |
+
|
130 |
# Add a short pause between speakers
|
131 |
pause = AudioSegment.silent(duration=500) # 500ms pause
|
132 |
generated_audio_segments.append(pause)
|
133 |
+
|
134 |
# Concatenate all audio segments
|
135 |
final_podcast = sum(generated_audio_segments)
|
136 |
+
|
137 |
# Export the final podcast
|
138 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
139 |
podcast_path = temp_file.name
|
140 |
final_podcast.export(podcast_path, format="wav")
|
141 |
+
|
142 |
return podcast_path
|
143 |
|
144 |
+
|
145 |
def parse_speechtypes_text(gen_text):
|
146 |
# Pattern to find (Emotion)
|
147 |
+
pattern = r"\((.*?)\)"
|
148 |
|
149 |
# Split the text by the pattern
|
150 |
tokens = re.split(pattern, gen_text)
|
151 |
|
152 |
segments = []
|
153 |
|
154 |
+
current_emotion = "Regular"
|
155 |
|
156 |
for i in range(len(tokens)):
|
157 |
if i % 2 == 0:
|
158 |
# This is text
|
159 |
text = tokens[i].strip()
|
160 |
if text:
|
161 |
+
segments.append({"emotion": current_emotion, "text": text})
|
162 |
else:
|
163 |
# This is emotion
|
164 |
emotion = tokens[i].strip()
|
|
|
179 |
gr.Markdown("# Batched TTS")
|
180 |
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
181 |
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
|
182 |
+
model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
|
|
|
|
183 |
generate_btn = gr.Button("Synthesize", variant="primary")
|
184 |
with gr.Accordion("Advanced Settings", open=False):
|
185 |
ref_text_input = gr.Textbox(
|
|
|
225 |
],
|
226 |
outputs=[audio_output, spectrogram_output],
|
227 |
)
|
228 |
+
|
229 |
with gr.Blocks() as app_podcast:
|
230 |
gr.Markdown("# Podcast Generation")
|
231 |
speaker1_name = gr.Textbox(label="Speaker 1 Name")
|
232 |
ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
|
233 |
ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
|
234 |
+
|
235 |
speaker2_name = gr.Textbox(label="Speaker 2 Name")
|
236 |
ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
|
237 |
ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
|
238 |
+
|
239 |
+
script_input = gr.Textbox(
|
240 |
+
label="Podcast Script",
|
241 |
+
lines=10,
|
242 |
+
placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...",
|
|
|
243 |
)
|
244 |
+
|
245 |
+
podcast_model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
246 |
podcast_remove_silence = gr.Checkbox(
|
247 |
label="Remove Silences",
|
248 |
value=True,
|
|
|
250 |
generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
|
251 |
podcast_output = gr.Audio(label="Generated Podcast")
|
252 |
|
253 |
+
def podcast_generation(
|
254 |
+
script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
|
255 |
+
):
|
256 |
+
return generate_podcast(
|
257 |
+
script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
|
258 |
+
)
|
259 |
|
260 |
generate_podcast_btn.click(
|
261 |
podcast_generation,
|
|
|
273 |
outputs=podcast_output,
|
274 |
)
|
275 |
|
276 |
+
|
277 |
def parse_emotional_text(gen_text):
|
278 |
# Pattern to find (Emotion)
|
279 |
+
pattern = r"\((.*?)\)"
|
280 |
|
281 |
# Split the text by the pattern
|
282 |
tokens = re.split(pattern, gen_text)
|
283 |
|
284 |
segments = []
|
285 |
|
286 |
+
current_emotion = "Regular"
|
287 |
|
288 |
for i in range(len(tokens)):
|
289 |
if i % 2 == 0:
|
290 |
# This is text
|
291 |
text = tokens[i].strip()
|
292 |
if text:
|
293 |
+
segments.append({"emotion": current_emotion, "text": text})
|
294 |
else:
|
295 |
# This is emotion
|
296 |
emotion = tokens[i].strip()
|
|
|
298 |
|
299 |
return segments
|
300 |
|
301 |
+
|
302 |
with gr.Blocks() as app_emotional:
|
303 |
# New section for emotional generation
|
304 |
gr.Markdown(
|
|
|
313 |
"""
|
314 |
)
|
315 |
|
316 |
+
gr.Markdown(
|
317 |
+
"Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
|
318 |
+
)
|
319 |
|
320 |
# Regular speech type (mandatory)
|
321 |
with gr.Row():
|
322 |
+
regular_name = gr.Textbox(value="Regular", label="Speech Type Name", interactive=False)
|
323 |
+
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
|
324 |
+
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
|
325 |
|
326 |
# Additional speech types (up to 99 more)
|
327 |
max_speech_types = 100
|
|
|
332 |
|
333 |
for i in range(max_speech_types - 1):
|
334 |
with gr.Row():
|
335 |
+
name_input = gr.Textbox(label="Speech Type Name", visible=False)
|
336 |
+
audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
|
337 |
+
ref_text_input = gr.Textbox(label="Reference Text", lines=2, visible=False)
|
338 |
delete_btn = gr.Button("Delete", variant="secondary", visible=False)
|
339 |
speech_type_names.append(name_input)
|
340 |
speech_type_audios.append(audio_input)
|
|
|
379 |
add_speech_type_btn.click(
|
380 |
add_speech_type_fn,
|
381 |
inputs=speech_type_count,
|
382 |
+
outputs=[speech_type_count]
|
383 |
+
+ speech_type_names
|
384 |
+
+ speech_type_audios
|
385 |
+
+ speech_type_ref_texts
|
386 |
+
+ speech_type_delete_btns,
|
387 |
)
|
388 |
|
389 |
# Function to delete a speech type
|
|
|
397 |
|
398 |
for i in range(max_speech_types - 1):
|
399 |
if i == index:
|
400 |
+
name_updates.append(gr.update(visible=False, value=""))
|
401 |
audio_updates.append(gr.update(visible=False, value=None))
|
402 |
+
ref_text_updates.append(gr.update(visible=False, value=""))
|
403 |
delete_btn_updates.append(gr.update(visible=False))
|
404 |
else:
|
405 |
name_updates.append(gr.update())
|
|
|
418 |
delete_btn.click(
|
419 |
delete_fn,
|
420 |
inputs=speech_type_count,
|
421 |
+
outputs=[speech_type_count]
|
422 |
+
+ speech_type_names
|
423 |
+
+ speech_type_audios
|
424 |
+
+ speech_type_ref_texts
|
425 |
+
+ speech_type_delete_btns,
|
426 |
)
|
427 |
|
428 |
# Text input for the prompt
|
429 |
gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
|
430 |
|
431 |
# Model choice
|
432 |
+
model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
|
|
|
|
433 |
|
434 |
with gr.Accordion("Advanced Settings", open=False):
|
435 |
remove_silence_emotional = gr.Checkbox(
|
|
|
442 |
|
443 |
# Output audio
|
444 |
audio_output_emotional = gr.Audio(label="Synthesized Audio")
|
445 |
+
|
446 |
@gpu_decorator
|
447 |
def generate_emotional_speech(
|
448 |
regular_audio,
|
|
|
452 |
):
|
453 |
num_additional_speech_types = max_speech_types - 1
|
454 |
speech_type_names_list = args[:num_additional_speech_types]
|
455 |
+
speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
|
456 |
+
speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
|
457 |
model_choice = args[3 * num_additional_speech_types]
|
458 |
remove_silence = args[3 * num_additional_speech_types + 1]
|
459 |
|
460 |
# Collect the speech types and their audios into a dict
|
461 |
+
speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
|
462 |
|
463 |
+
for name_input, audio_input, ref_text_input in zip(
|
464 |
+
speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
|
465 |
+
):
|
466 |
if name_input and audio_input:
|
467 |
+
speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
|
468 |
|
469 |
# Parse the gen_text into segments
|
470 |
segments = parse_speechtypes_text(gen_text)
|
471 |
|
472 |
# For each segment, generate speech
|
473 |
generated_audio_segments = []
|
474 |
+
current_emotion = "Regular"
|
475 |
|
476 |
for segment in segments:
|
477 |
+
emotion = segment["emotion"]
|
478 |
+
text = segment["text"]
|
479 |
|
480 |
if emotion in speech_types:
|
481 |
current_emotion = emotion
|
482 |
else:
|
483 |
# If emotion not available, default to Regular
|
484 |
+
current_emotion = "Regular"
|
485 |
|
486 |
+
ref_audio = speech_types[current_emotion]["audio"]
|
487 |
+
ref_text = speech_types[current_emotion].get("ref_text", "")
|
488 |
|
489 |
# Generate speech for this segment
|
490 |
audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
|
|
|
506 |
regular_audio,
|
507 |
regular_ref_text,
|
508 |
gen_text_input_emotional,
|
509 |
+
]
|
510 |
+
+ speech_type_names
|
511 |
+
+ speech_type_audios
|
512 |
+
+ speech_type_ref_texts
|
513 |
+
+ [
|
514 |
model_choice_emotional,
|
515 |
remove_silence_emotional,
|
516 |
],
|
|
|
518 |
)
|
519 |
|
520 |
# Validation function to disable Generate button if speech types are missing
|
521 |
+
def validate_speech_types(gen_text, regular_name, *args):
|
|
|
|
|
|
|
|
|
522 |
num_additional_speech_types = max_speech_types - 1
|
523 |
speech_type_names_list = args[:num_additional_speech_types]
|
524 |
|
|
|
532 |
|
533 |
# Parse the gen_text to get the speech types used
|
534 |
segments = parse_emotional_text(gen_text)
|
535 |
+
speech_types_in_text = set(segment["emotion"] for segment in segments)
|
536 |
|
537 |
# Check if all speech types in text are available
|
538 |
missing_speech_types = speech_types_in_text - speech_types_available
|
|
|
547 |
gen_text_input_emotional.change(
|
548 |
validate_speech_types,
|
549 |
inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
|
550 |
+
outputs=generate_emotional_btn,
|
551 |
)
|
552 |
with gr.Blocks() as app:
|
553 |
gr.Markdown(
|
|
|
568 |
)
|
569 |
gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
|
570 |
|
571 |
+
|
572 |
@click.command()
|
573 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
574 |
@click.option("--host", "-H", default=None, help="Host to run the app on")
|
|
|
582 |
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
|
583 |
def main(port, host, share, api):
|
584 |
global app
|
585 |
+
print("Starting app...")
|
586 |
+
app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
|
|
|
|
|
587 |
|
588 |
|
589 |
if __name__ == "__main__":
|
finetune-cli.py
CHANGED
@@ -1,42 +1,57 @@
|
|
1 |
import argparse
|
2 |
-
from model import CFM, UNetT, DiT,
|
3 |
from model.utils import get_tokenizer
|
4 |
from model.dataset import load_dataset
|
5 |
from cached_path import cached_path
|
6 |
-
import shutil
|
|
|
|
|
7 |
# -------------------------- Dataset Settings --------------------------- #
|
8 |
target_sample_rate = 24000
|
9 |
n_mel_channels = 100
|
10 |
hop_length = 256
|
11 |
|
12 |
-
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
|
13 |
-
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
|
14 |
|
15 |
# -------------------------- Argument Parsing --------------------------- #
|
16 |
def parse_args():
|
17 |
-
parser = argparse.ArgumentParser(description=
|
18 |
-
|
19 |
-
parser.add_argument(
|
20 |
-
|
21 |
-
|
22 |
-
parser.add_argument(
|
23 |
-
parser.add_argument(
|
24 |
-
parser.add_argument(
|
25 |
-
parser.add_argument(
|
26 |
-
|
27 |
-
|
28 |
-
parser.add_argument(
|
29 |
-
parser.add_argument(
|
30 |
-
parser.add_argument(
|
31 |
-
parser.add_argument(
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
return parser.parse_args()
|
34 |
|
|
|
35 |
# -------------------------- Training Settings -------------------------- #
|
36 |
|
|
|
37 |
def main():
|
38 |
args = parse_args()
|
39 |
-
|
40 |
|
41 |
# Model parameters based on experiment name
|
42 |
if args.exp_name == "F5TTS_Base":
|
@@ -44,24 +59,31 @@ def main():
|
|
44 |
model_cls = DiT
|
45 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
46 |
if args.finetune:
|
47 |
-
|
48 |
elif args.exp_name == "E2TTS_Base":
|
49 |
wandb_resume_id = None
|
50 |
model_cls = UNetT
|
51 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
52 |
if args.finetune:
|
53 |
-
|
54 |
-
|
55 |
if args.finetune:
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
62 |
-
|
63 |
-
# Use the
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
66 |
|
67 |
mel_spec_kwargs = dict(
|
@@ -71,11 +93,7 @@ def main():
|
|
71 |
)
|
72 |
|
73 |
e2tts = CFM(
|
74 |
-
transformer=model_cls(
|
75 |
-
**model_cfg,
|
76 |
-
text_num_embeds=vocab_size,
|
77 |
-
mel_dim=n_mel_channels
|
78 |
-
),
|
79 |
mel_spec_kwargs=mel_spec_kwargs,
|
80 |
vocab_char_map=vocab_char_map,
|
81 |
)
|
@@ -99,10 +117,11 @@ def main():
|
|
99 |
)
|
100 |
|
101 |
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
102 |
-
trainer.train(
|
103 |
-
|
104 |
-
|
|
|
105 |
|
106 |
|
107 |
-
if __name__ ==
|
108 |
main()
|
|
|
1 |
import argparse
|
2 |
+
from model import CFM, UNetT, DiT, Trainer
|
3 |
from model.utils import get_tokenizer
|
4 |
from model.dataset import load_dataset
|
5 |
from cached_path import cached_path
|
6 |
+
import shutil
|
7 |
+
import os
|
8 |
+
|
9 |
# -------------------------- Dataset Settings --------------------------- #
|
10 |
target_sample_rate = 24000
|
11 |
n_mel_channels = 100
|
12 |
hop_length = 256
|
13 |
|
|
|
|
|
14 |
|
15 |
# -------------------------- Argument Parsing --------------------------- #
|
16 |
def parse_args():
|
17 |
+
parser = argparse.ArgumentParser(description="Train CFM Model")
|
18 |
+
|
19 |
+
parser.add_argument(
|
20 |
+
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
|
21 |
+
)
|
22 |
+
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
|
23 |
+
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
|
24 |
+
parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
|
25 |
+
parser.add_argument(
|
26 |
+
"--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
|
27 |
+
)
|
28 |
+
parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
|
29 |
+
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
|
30 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
|
31 |
+
parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
|
32 |
+
parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
|
33 |
+
parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
|
34 |
+
parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
|
35 |
+
parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
"--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--tokenizer_path",
|
42 |
+
type=str,
|
43 |
+
default=None,
|
44 |
+
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
|
45 |
+
)
|
46 |
+
|
47 |
return parser.parse_args()
|
48 |
|
49 |
+
|
50 |
# -------------------------- Training Settings -------------------------- #
|
51 |
|
52 |
+
|
53 |
def main():
|
54 |
args = parse_args()
|
|
|
55 |
|
56 |
# Model parameters based on experiment name
|
57 |
if args.exp_name == "F5TTS_Base":
|
|
|
59 |
model_cls = DiT
|
60 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
61 |
if args.finetune:
|
62 |
+
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
63 |
elif args.exp_name == "E2TTS_Base":
|
64 |
wandb_resume_id = None
|
65 |
model_cls = UNetT
|
66 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
67 |
if args.finetune:
|
68 |
+
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
69 |
+
|
70 |
if args.finetune:
|
71 |
+
path_ckpt = os.path.join("ckpts", args.dataset_name)
|
72 |
+
if not os.path.isdir(path_ckpt):
|
73 |
+
os.makedirs(path_ckpt, exist_ok=True)
|
74 |
+
shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
|
75 |
+
|
76 |
+
checkpoint_path = os.path.join("ckpts", args.dataset_name)
|
77 |
+
|
78 |
+
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
79 |
+
tokenizer = args.tokenizer
|
80 |
+
if tokenizer == "custom":
|
81 |
+
if not args.tokenizer_path:
|
82 |
+
raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
|
83 |
+
tokenizer_path = args.tokenizer_path
|
84 |
+
else:
|
85 |
+
tokenizer_path = args.dataset_name
|
86 |
+
|
87 |
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
88 |
|
89 |
mel_spec_kwargs = dict(
|
|
|
93 |
)
|
94 |
|
95 |
e2tts = CFM(
|
96 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
|
|
|
|
|
|
|
|
97 |
mel_spec_kwargs=mel_spec_kwargs,
|
98 |
vocab_char_map=vocab_char_map,
|
99 |
)
|
|
|
117 |
)
|
118 |
|
119 |
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
120 |
+
trainer.train(
|
121 |
+
train_dataset,
|
122 |
+
resumable_with_seed=666, # seed for shuffling dataset
|
123 |
+
)
|
124 |
|
125 |
|
126 |
+
if __name__ == "__main__":
|
127 |
main()
|
finetune_gradio.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
import os
|
|
|
2 |
|
3 |
from transformers import pipeline
|
4 |
import gradio as gr
|
@@ -20,34 +21,37 @@ import platform
|
|
20 |
import subprocess
|
21 |
from datasets.arrow_writer import ArrowWriter
|
22 |
|
23 |
-
import json
|
24 |
|
25 |
-
training_process = None
|
26 |
system = platform.system()
|
27 |
python_executable = sys.executable or "python"
|
28 |
|
29 |
-
path_data="data"
|
30 |
|
31 |
-
device = (
|
32 |
-
"cuda"
|
33 |
-
if torch.cuda.is_available()
|
34 |
-
else "mps" if torch.backends.mps.is_available() else "cpu"
|
35 |
-
)
|
36 |
|
37 |
pipe = None
|
38 |
|
|
|
39 |
# Load metadata
|
40 |
def get_audio_duration(audio_path):
|
41 |
"""Calculate the duration of an audio file."""
|
42 |
audio, sample_rate = torchaudio.load(audio_path)
|
43 |
-
num_channels = audio.shape[0]
|
44 |
return audio.shape[1] / (sample_rate * num_channels)
|
45 |
|
|
|
46 |
def clear_text(text):
|
47 |
"""Clean and prepare text by lowering the case and stripping whitespace."""
|
48 |
return text.lower().strip()
|
49 |
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
padding = (int(frame_length // 2), int(frame_length // 2))
|
52 |
y = np.pad(y, padding, mode=pad_mode)
|
53 |
|
@@ -74,7 +78,8 @@ def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://
|
|
74 |
|
75 |
return np.sqrt(power)
|
76 |
|
77 |
-
|
|
|
78 |
def __init__(
|
79 |
self,
|
80 |
sr: int,
|
@@ -85,13 +90,9 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
85 |
max_sil_kept: int = 2000,
|
86 |
):
|
87 |
if not min_length >= min_interval >= hop_size:
|
88 |
-
raise ValueError(
|
89 |
-
"The following condition must be satisfied: min_length >= min_interval >= hop_size"
|
90 |
-
)
|
91 |
if not max_sil_kept >= hop_size:
|
92 |
-
raise ValueError(
|
93 |
-
"The following condition must be satisfied: max_sil_kept >= hop_size"
|
94 |
-
)
|
95 |
min_interval = sr * min_interval / 1000
|
96 |
self.threshold = 10 ** (threshold / 20.0)
|
97 |
self.hop_size = round(sr * hop_size / 1000)
|
@@ -102,13 +103,9 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
102 |
|
103 |
def _apply_slice(self, waveform, begin, end):
|
104 |
if len(waveform.shape) > 1:
|
105 |
-
return waveform[
|
106 |
-
:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
|
107 |
-
]
|
108 |
else:
|
109 |
-
return waveform[
|
110 |
-
begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
|
111 |
-
]
|
112 |
|
113 |
# @timeit
|
114 |
def slice(self, waveform):
|
@@ -118,9 +115,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
118 |
samples = waveform
|
119 |
if samples.shape[0] <= self.min_length:
|
120 |
return [waveform]
|
121 |
-
rms_list = get_rms(
|
122 |
-
y=samples, frame_length=self.win_size, hop_length=self.hop_size
|
123 |
-
).squeeze(0)
|
124 |
sil_tags = []
|
125 |
silence_start = None
|
126 |
clip_start = 0
|
@@ -136,10 +131,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
136 |
continue
|
137 |
# Clear recorded silence start if interval is not enough or clip is too short
|
138 |
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
139 |
-
need_slice_middle =
|
140 |
-
i - silence_start >= self.min_interval
|
141 |
-
and i - clip_start >= self.min_length
|
142 |
-
)
|
143 |
if not is_leading_silence and not need_slice_middle:
|
144 |
silence_start = None
|
145 |
continue
|
@@ -152,21 +144,10 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
152 |
sil_tags.append((pos, pos))
|
153 |
clip_start = pos
|
154 |
elif i - silence_start <= self.max_sil_kept * 2:
|
155 |
-
pos = rms_list[
|
156 |
-
i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
|
157 |
-
].argmin()
|
158 |
pos += i - self.max_sil_kept
|
159 |
-
pos_l = (
|
160 |
-
|
161 |
-
silence_start : silence_start + self.max_sil_kept + 1
|
162 |
-
].argmin()
|
163 |
-
+ silence_start
|
164 |
-
)
|
165 |
-
pos_r = (
|
166 |
-
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
167 |
-
+ i
|
168 |
-
- self.max_sil_kept
|
169 |
-
)
|
170 |
if silence_start == 0:
|
171 |
sil_tags.append((0, pos_r))
|
172 |
clip_start = pos_r
|
@@ -174,17 +155,8 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
174 |
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
175 |
clip_start = max(pos_r, pos)
|
176 |
else:
|
177 |
-
pos_l = (
|
178 |
-
|
179 |
-
silence_start : silence_start + self.max_sil_kept + 1
|
180 |
-
].argmin()
|
181 |
-
+ silence_start
|
182 |
-
)
|
183 |
-
pos_r = (
|
184 |
-
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
185 |
-
+ i
|
186 |
-
- self.max_sil_kept
|
187 |
-
)
|
188 |
if silence_start == 0:
|
189 |
sil_tags.append((0, pos_r))
|
190 |
else:
|
@@ -193,33 +165,39 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
|
193 |
silence_start = None
|
194 |
# Deal with trailing silence.
|
195 |
total_frames = rms_list.shape[0]
|
196 |
-
if
|
197 |
-
silence_start is not None
|
198 |
-
and total_frames - silence_start >= self.min_interval
|
199 |
-
):
|
200 |
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
201 |
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
202 |
sil_tags.append((pos, total_frames + 1))
|
203 |
# Apply and return slices.
|
204 |
####音频+起始时间+终止时间
|
205 |
if len(sil_tags) == 0:
|
206 |
-
return [[waveform,0,int(total_frames*self.hop_size)]]
|
207 |
else:
|
208 |
chunks = []
|
209 |
if sil_tags[0][0] > 0:
|
210 |
-
chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
|
211 |
for i in range(len(sil_tags) - 1):
|
212 |
chunks.append(
|
213 |
-
[
|
|
|
|
|
|
|
|
|
214 |
)
|
215 |
if sil_tags[-1][1] < total_frames:
|
216 |
chunks.append(
|
217 |
-
[
|
|
|
|
|
|
|
|
|
218 |
)
|
219 |
return chunks
|
220 |
|
221 |
-
|
222 |
-
|
|
|
223 |
try:
|
224 |
parent = psutil.Process(pid)
|
225 |
except psutil.NoSuchProcess:
|
@@ -238,6 +216,7 @@ def terminate_process_tree(pid, including_parent=True):
|
|
238 |
except OSError:
|
239 |
pass
|
240 |
|
|
|
241 |
def terminate_process(pid):
|
242 |
if system == "Windows":
|
243 |
cmd = f"taskkill /t /f /pid {pid}"
|
@@ -245,132 +224,154 @@ def terminate_process(pid):
|
|
245 |
else:
|
246 |
terminate_process_tree(pid)
|
247 |
|
248 |
-
def start_training(dataset_name="",
|
249 |
-
exp_name="F5TTS_Base",
|
250 |
-
learning_rate=1e-4,
|
251 |
-
batch_size_per_gpu=400,
|
252 |
-
batch_size_type="frame",
|
253 |
-
max_samples=64,
|
254 |
-
grad_accumulation_steps=1,
|
255 |
-
max_grad_norm=1.0,
|
256 |
-
epochs=11,
|
257 |
-
num_warmup_updates=200,
|
258 |
-
save_per_updates=400,
|
259 |
-
last_per_steps=800,
|
260 |
-
finetune=True,
|
261 |
-
):
|
262 |
-
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
global training_process
|
265 |
|
266 |
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
267 |
|
268 |
-
if os.path.isdir(path_project)
|
269 |
-
yield
|
|
|
|
|
|
|
|
|
270 |
return
|
271 |
|
272 |
-
file_raw = os.path.join(path_project,"raw.arrow")
|
273 |
-
if os.path.isfile(file_raw)
|
274 |
-
|
275 |
-
|
276 |
|
277 |
# Check if a training process is already running
|
278 |
if training_process is not None:
|
279 |
-
return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
|
280 |
|
281 |
-
yield "start train",gr.update(interactive=False),gr.update(interactive=False)
|
282 |
|
283 |
# Command to run the training script with the specified arguments
|
284 |
-
cmd =
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
297 |
|
298 |
print(cmd)
|
299 |
-
|
300 |
try:
|
301 |
-
|
302 |
-
|
303 |
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
|
316 |
except Exception as e: # Catch all exceptions
|
317 |
# Ensure that we reset the training process variable in case of an error
|
318 |
-
text_info=f"An error occurred: {str(e)}"
|
319 |
-
|
320 |
-
training_process=None
|
|
|
|
|
321 |
|
322 |
-
yield text_info,gr.update(interactive=True),gr.update(interactive=False)
|
323 |
|
324 |
def stop_training():
|
325 |
global training_process
|
326 |
-
if training_process is None:
|
|
|
327 |
terminate_process_tree(training_process.pid)
|
328 |
training_process = None
|
329 |
-
return
|
|
|
330 |
|
331 |
def create_data_project(name):
|
332 |
-
name+="_pinyin"
|
333 |
-
os.makedirs(os.path.join(path_data,name),exist_ok=True)
|
334 |
-
os.makedirs(os.path.join(path_data,name,"dataset"),exist_ok=True)
|
335 |
-
|
336 |
-
|
|
|
337 |
global pipe
|
338 |
|
339 |
if pipe is None:
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
text_transcribe = pipe(
|
343 |
file_audio,
|
344 |
chunk_length_s=30,
|
345 |
batch_size=128,
|
346 |
-
generate_kwargs={"task": "transcribe","language": language},
|
347 |
return_timestamps=False,
|
348 |
)["text"].strip()
|
349 |
return text_transcribe
|
350 |
|
351 |
-
def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
|
352 |
-
name_project+="_pinyin"
|
353 |
-
path_project= os.path.join(path_data,name_project)
|
354 |
-
path_dataset = os.path.join(path_project,"dataset")
|
355 |
-
path_project_wavs = os.path.join(path_project,"wavs")
|
356 |
-
file_metadata = os.path.join(path_project,"metadata.csv")
|
357 |
|
358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
360 |
if os.path.isdir(path_project_wavs):
|
361 |
-
|
362 |
|
363 |
if os.path.isfile(file_metadata):
|
364 |
-
|
|
|
|
|
365 |
|
366 |
-
os.makedirs(path_project_wavs,exist_ok=True)
|
367 |
-
|
368 |
if user:
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
371 |
else:
|
372 |
-
|
373 |
-
|
374 |
|
375 |
alpha = 0.5
|
376 |
_max = 1.0
|
@@ -378,181 +379,202 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
|
|
378 |
|
379 |
num = 0
|
380 |
error_num = 0
|
381 |
-
data=""
|
382 |
-
for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
list_slicer=
|
387 |
-
for chunk, start, end in progress.tqdm(list_slicer,total=len(list_slicer), desc="slicer files"):
|
388 |
-
|
389 |
name_segment = os.path.join(f"segment_{num}")
|
390 |
-
file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
|
391 |
-
|
392 |
tmp_max = np.abs(chunk).max()
|
393 |
-
if
|
|
|
394 |
chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
|
395 |
-
wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
|
396 |
-
|
397 |
try:
|
398 |
-
|
399 |
-
|
400 |
|
401 |
-
|
402 |
|
403 |
-
|
404 |
-
except:
|
405 |
-
|
406 |
|
407 |
-
with open(file_metadata,"w",encoding="utf-8") as f:
|
408 |
f.write(data)
|
409 |
-
|
410 |
-
if error_num!=[]:
|
411 |
-
|
412 |
else:
|
413 |
-
|
414 |
-
|
415 |
return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
|
416 |
|
|
|
417 |
def format_seconds_to_hms(seconds):
|
418 |
hours = int(seconds / 3600)
|
419 |
minutes = int((seconds % 3600) / 60)
|
420 |
seconds = seconds % 60
|
421 |
return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
|
|
|
|
|
|
449 |
|
450 |
file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
|
451 |
|
452 |
-
if os.path.isfile(file_audio)
|
453 |
error_files.append(file_audio)
|
454 |
continue
|
455 |
|
456 |
duraction = get_audio_duration(file_audio)
|
457 |
-
if duraction<2 and duraction>15:
|
458 |
-
|
|
|
|
|
459 |
|
460 |
text = clear_text(text)
|
461 |
-
text = convert_char_to_pinyin([text], polyphone
|
462 |
|
463 |
audio_path_list.append(file_audio)
|
464 |
duration_list.append(duraction)
|
465 |
text_list.append(text)
|
466 |
-
|
467 |
result.append({"audio_path": file_audio, "text": text, "duration": duraction})
|
468 |
|
469 |
-
lenght+=duraction
|
470 |
|
471 |
-
if duration_list==[]:
|
472 |
-
error_files_text="\n".join(error_files)
|
473 |
return f"Error: No audio files found in the specified path : \n{error_files_text}"
|
474 |
-
|
475 |
-
min_second = round(min(duration_list),2)
|
476 |
-
max_second = round(max(duration_list),2)
|
477 |
|
478 |
with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
|
479 |
-
for line in progress.tqdm(result,total=len(result), desc=
|
480 |
writer.write(line)
|
481 |
|
482 |
-
with open(file_duration,
|
483 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
484 |
-
|
485 |
-
file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
486 |
-
if os.path.isfile(file_vocab_finetune
|
|
|
487 |
shutil.copy2(file_vocab_finetune, file_vocab)
|
488 |
-
|
489 |
-
if error_files!=[]:
|
490 |
-
|
491 |
else:
|
492 |
-
|
493 |
-
|
494 |
return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
|
495 |
|
|
|
496 |
def check_user(value):
|
497 |
-
return gr.update(visible=not value),gr.update(visible=value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
498 |
|
499 |
-
|
500 |
-
|
501 |
-
path_project= os.path.join(path_data,name_project)
|
502 |
-
file_duraction = os.path.join(path_project,"duration.json")
|
503 |
|
504 |
-
|
505 |
-
data = json.load(file)
|
506 |
-
|
507 |
-
duration_list = data['duration']
|
508 |
|
509 |
samples = len(duration_list)
|
510 |
|
511 |
if torch.cuda.is_available():
|
512 |
gpu_properties = torch.cuda.get_device_properties(0)
|
513 |
-
total_memory = gpu_properties.total_memory / (1024
|
514 |
elif torch.backends.mps.is_available():
|
515 |
-
total_memory = psutil.virtual_memory().available / (1024
|
516 |
-
|
517 |
-
if batch_size_type=="frame":
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
else:
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
|
526 |
-
if batch_size_per_gpu<=0:
|
|
|
527 |
|
528 |
-
if samples<64:
|
529 |
-
|
530 |
else:
|
531 |
-
|
532 |
-
|
533 |
-
num_warmup_updates = int(samples * 0.10)
|
534 |
-
save_per_updates = int(samples * 0.25)
|
535 |
-
last_per_steps =int(save_per_updates * 5)
|
536 |
-
|
537 |
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
538 |
num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
|
539 |
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
|
540 |
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
541 |
|
542 |
-
if finetune:
|
543 |
-
|
|
|
|
|
|
|
|
|
544 |
|
545 |
-
return batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,samples,learning_rate
|
546 |
|
547 |
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
|
548 |
try:
|
549 |
checkpoint = torch.load(checkpoint_path)
|
550 |
print("Original Checkpoint Keys:", checkpoint.keys())
|
551 |
-
|
552 |
-
ema_model_state_dict = checkpoint.get(
|
553 |
|
554 |
if ema_model_state_dict is not None:
|
555 |
-
new_checkpoint = {
|
556 |
torch.save(new_checkpoint, new_checkpoint_path)
|
557 |
return f"New checkpoint saved at: {new_checkpoint_path}"
|
558 |
else:
|
@@ -561,65 +583,61 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
|
|
561 |
except Exception as e:
|
562 |
return f"An error occurred: {e}"
|
563 |
|
|
|
564 |
def vocab_check(project_name):
|
565 |
name_project = project_name + "_pinyin"
|
566 |
path_project = os.path.join(path_data, name_project)
|
567 |
|
568 |
file_metadata = os.path.join(path_project, "metadata.csv")
|
569 |
-
|
570 |
-
file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
|
571 |
-
if os.path.isfile(file_vocab)
|
572 |
return f"the file {file_vocab} not found !"
|
573 |
-
|
574 |
-
with open(file_vocab,"r",encoding="utf-8") as f:
|
575 |
-
|
576 |
|
577 |
vocab = data.split("\n")
|
578 |
|
579 |
-
if os.path.isfile(file_metadata)
|
580 |
return f"the file {file_metadata} not found !"
|
581 |
|
582 |
-
with open(file_metadata,"r",encoding="utf-8") as f:
|
583 |
-
|
584 |
|
585 |
-
miss_symbols=[]
|
586 |
-
miss_symbols_keep={}
|
587 |
for item in data.split("\n"):
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
else:
|
|
|
600 |
|
601 |
return info
|
602 |
|
603 |
|
604 |
-
|
605 |
with gr.Blocks() as app:
|
606 |
-
|
607 |
with gr.Row():
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
bt_create.click(fn=create_data_project,inputs=[project_name])
|
612 |
-
|
613 |
-
with gr.Tabs():
|
614 |
-
|
615 |
|
616 |
-
|
617 |
|
|
|
|
|
|
|
618 |
|
619 |
-
|
620 |
-
|
621 |
-
mark_info_transcribe=gr.Markdown(
|
622 |
-
"""```plaintext
|
623 |
Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
|
624 |
|
625 |
my_speak/
|
@@ -628,18 +646,24 @@ with gr.Blocks() as app:
|
|
628 |
├── audio1.wav
|
629 |
└── audio2.wav
|
630 |
...
|
631 |
-
```""",
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
643 |
place all your wavs folder and your metadata.csv file in {your name project}
|
644 |
my_speak/
|
645 |
│
|
@@ -656,61 +680,104 @@ with gr.Blocks() as app:
|
|
656 |
audio2|text1
|
657 |
...
|
658 |
|
659 |
-
```"""
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
714 |
|
715 |
@click.command()
|
716 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
@@ -725,10 +792,9 @@ with gr.Blocks() as app:
|
|
725 |
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
|
726 |
def main(port, host, share, api):
|
727 |
global app
|
728 |
-
print(
|
729 |
-
app.queue(api_open=api).launch(
|
730 |
-
|
731 |
-
)
|
732 |
|
733 |
if __name__ == "__main__":
|
734 |
main()
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
|
4 |
from transformers import pipeline
|
5 |
import gradio as gr
|
|
|
21 |
import subprocess
|
22 |
from datasets.arrow_writer import ArrowWriter
|
23 |
|
|
|
24 |
|
25 |
+
training_process = None
|
26 |
system = platform.system()
|
27 |
python_executable = sys.executable or "python"
|
28 |
|
29 |
+
path_data = "data"
|
30 |
|
31 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
32 |
|
33 |
pipe = None
|
34 |
|
35 |
+
|
36 |
# Load metadata
|
37 |
def get_audio_duration(audio_path):
|
38 |
"""Calculate the duration of an audio file."""
|
39 |
audio, sample_rate = torchaudio.load(audio_path)
|
40 |
+
num_channels = audio.shape[0]
|
41 |
return audio.shape[1] / (sample_rate * num_channels)
|
42 |
|
43 |
+
|
44 |
def clear_text(text):
|
45 |
"""Clean and prepare text by lowering the case and stripping whitespace."""
|
46 |
return text.lower().strip()
|
47 |
|
48 |
+
|
49 |
+
def get_rms(
|
50 |
+
y,
|
51 |
+
frame_length=2048,
|
52 |
+
hop_length=512,
|
53 |
+
pad_mode="constant",
|
54 |
+
): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
|
55 |
padding = (int(frame_length // 2), int(frame_length // 2))
|
56 |
y = np.pad(y, padding, mode=pad_mode)
|
57 |
|
|
|
78 |
|
79 |
return np.sqrt(power)
|
80 |
|
81 |
+
|
82 |
+
class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
|
83 |
def __init__(
|
84 |
self,
|
85 |
sr: int,
|
|
|
90 |
max_sil_kept: int = 2000,
|
91 |
):
|
92 |
if not min_length >= min_interval >= hop_size:
|
93 |
+
raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
|
|
|
|
|
94 |
if not max_sil_kept >= hop_size:
|
95 |
+
raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
|
|
|
|
|
96 |
min_interval = sr * min_interval / 1000
|
97 |
self.threshold = 10 ** (threshold / 20.0)
|
98 |
self.hop_size = round(sr * hop_size / 1000)
|
|
|
103 |
|
104 |
def _apply_slice(self, waveform, begin, end):
|
105 |
if len(waveform.shape) > 1:
|
106 |
+
return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
|
|
|
|
|
107 |
else:
|
108 |
+
return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
|
|
|
|
|
109 |
|
110 |
# @timeit
|
111 |
def slice(self, waveform):
|
|
|
115 |
samples = waveform
|
116 |
if samples.shape[0] <= self.min_length:
|
117 |
return [waveform]
|
118 |
+
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
|
|
|
|
119 |
sil_tags = []
|
120 |
silence_start = None
|
121 |
clip_start = 0
|
|
|
131 |
continue
|
132 |
# Clear recorded silence start if interval is not enough or clip is too short
|
133 |
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
134 |
+
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
|
|
|
|
|
|
|
135 |
if not is_leading_silence and not need_slice_middle:
|
136 |
silence_start = None
|
137 |
continue
|
|
|
144 |
sil_tags.append((pos, pos))
|
145 |
clip_start = pos
|
146 |
elif i - silence_start <= self.max_sil_kept * 2:
|
147 |
+
pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
|
|
|
|
|
148 |
pos += i - self.max_sil_kept
|
149 |
+
pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
150 |
+
pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
if silence_start == 0:
|
152 |
sil_tags.append((0, pos_r))
|
153 |
clip_start = pos_r
|
|
|
155 |
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
156 |
clip_start = max(pos_r, pos)
|
157 |
else:
|
158 |
+
pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
159 |
+
pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
if silence_start == 0:
|
161 |
sil_tags.append((0, pos_r))
|
162 |
else:
|
|
|
165 |
silence_start = None
|
166 |
# Deal with trailing silence.
|
167 |
total_frames = rms_list.shape[0]
|
168 |
+
if silence_start is not None and total_frames - silence_start >= self.min_interval:
|
|
|
|
|
|
|
169 |
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
170 |
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
171 |
sil_tags.append((pos, total_frames + 1))
|
172 |
# Apply and return slices.
|
173 |
####音频+起始时间+终止时间
|
174 |
if len(sil_tags) == 0:
|
175 |
+
return [[waveform, 0, int(total_frames * self.hop_size)]]
|
176 |
else:
|
177 |
chunks = []
|
178 |
if sil_tags[0][0] > 0:
|
179 |
+
chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
|
180 |
for i in range(len(sil_tags) - 1):
|
181 |
chunks.append(
|
182 |
+
[
|
183 |
+
self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
|
184 |
+
int(sil_tags[i][1] * self.hop_size),
|
185 |
+
int(sil_tags[i + 1][0] * self.hop_size),
|
186 |
+
]
|
187 |
)
|
188 |
if sil_tags[-1][1] < total_frames:
|
189 |
chunks.append(
|
190 |
+
[
|
191 |
+
self._apply_slice(waveform, sil_tags[-1][1], total_frames),
|
192 |
+
int(sil_tags[-1][1] * self.hop_size),
|
193 |
+
int(total_frames * self.hop_size),
|
194 |
+
]
|
195 |
)
|
196 |
return chunks
|
197 |
|
198 |
+
|
199 |
+
# terminal
|
200 |
+
def terminate_process_tree(pid, including_parent=True):
|
201 |
try:
|
202 |
parent = psutil.Process(pid)
|
203 |
except psutil.NoSuchProcess:
|
|
|
216 |
except OSError:
|
217 |
pass
|
218 |
|
219 |
+
|
220 |
def terminate_process(pid):
|
221 |
if system == "Windows":
|
222 |
cmd = f"taskkill /t /f /pid {pid}"
|
|
|
224 |
else:
|
225 |
terminate_process_tree(pid)
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
+
def start_training(
|
229 |
+
dataset_name="",
|
230 |
+
exp_name="F5TTS_Base",
|
231 |
+
learning_rate=1e-4,
|
232 |
+
batch_size_per_gpu=400,
|
233 |
+
batch_size_type="frame",
|
234 |
+
max_samples=64,
|
235 |
+
grad_accumulation_steps=1,
|
236 |
+
max_grad_norm=1.0,
|
237 |
+
epochs=11,
|
238 |
+
num_warmup_updates=200,
|
239 |
+
save_per_updates=400,
|
240 |
+
last_per_steps=800,
|
241 |
+
finetune=True,
|
242 |
+
):
|
243 |
global training_process
|
244 |
|
245 |
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
246 |
|
247 |
+
if not os.path.isdir(path_project):
|
248 |
+
yield (
|
249 |
+
f"There is not project with name {dataset_name}",
|
250 |
+
gr.update(interactive=True),
|
251 |
+
gr.update(interactive=False),
|
252 |
+
)
|
253 |
return
|
254 |
|
255 |
+
file_raw = os.path.join(path_project, "raw.arrow")
|
256 |
+
if not os.path.isfile(file_raw):
|
257 |
+
yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
|
258 |
+
return
|
259 |
|
260 |
# Check if a training process is already running
|
261 |
if training_process is not None:
|
262 |
+
return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
|
263 |
|
264 |
+
yield "start train", gr.update(interactive=False), gr.update(interactive=False)
|
265 |
|
266 |
# Command to run the training script with the specified arguments
|
267 |
+
cmd = (
|
268 |
+
f"accelerate launch finetune-cli.py --exp_name {exp_name} "
|
269 |
+
f"--learning_rate {learning_rate} "
|
270 |
+
f"--batch_size_per_gpu {batch_size_per_gpu} "
|
271 |
+
f"--batch_size_type {batch_size_type} "
|
272 |
+
f"--max_samples {max_samples} "
|
273 |
+
f"--grad_accumulation_steps {grad_accumulation_steps} "
|
274 |
+
f"--max_grad_norm {max_grad_norm} "
|
275 |
+
f"--epochs {epochs} "
|
276 |
+
f"--num_warmup_updates {num_warmup_updates} "
|
277 |
+
f"--save_per_updates {save_per_updates} "
|
278 |
+
f"--last_per_steps {last_per_steps} "
|
279 |
+
f"--dataset_name {dataset_name}"
|
280 |
+
)
|
281 |
+
if finetune:
|
282 |
+
cmd += f" --finetune {finetune}"
|
283 |
|
284 |
print(cmd)
|
285 |
+
|
286 |
try:
|
287 |
+
# Start the training process
|
288 |
+
training_process = subprocess.Popen(cmd, shell=True)
|
289 |
|
290 |
+
time.sleep(5)
|
291 |
+
yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True)
|
292 |
+
|
293 |
+
# Wait for the training process to finish
|
294 |
+
training_process.wait()
|
295 |
+
time.sleep(1)
|
296 |
+
|
297 |
+
if training_process is None:
|
298 |
+
text_info = "train stop"
|
299 |
+
else:
|
300 |
+
text_info = "train complete !"
|
301 |
|
302 |
except Exception as e: # Catch all exceptions
|
303 |
# Ensure that we reset the training process variable in case of an error
|
304 |
+
text_info = f"An error occurred: {str(e)}"
|
305 |
+
|
306 |
+
training_process = None
|
307 |
+
|
308 |
+
yield text_info, gr.update(interactive=True), gr.update(interactive=False)
|
309 |
|
|
|
310 |
|
311 |
def stop_training():
|
312 |
global training_process
|
313 |
+
if training_process is None:
|
314 |
+
return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
|
315 |
terminate_process_tree(training_process.pid)
|
316 |
training_process = None
|
317 |
+
return "train stop", gr.update(interactive=True), gr.update(interactive=False)
|
318 |
+
|
319 |
|
320 |
def create_data_project(name):
|
321 |
+
name += "_pinyin"
|
322 |
+
os.makedirs(os.path.join(path_data, name), exist_ok=True)
|
323 |
+
os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
|
324 |
+
|
325 |
+
|
326 |
+
def transcribe(file_audio, language="english"):
|
327 |
global pipe
|
328 |
|
329 |
if pipe is None:
|
330 |
+
pipe = pipeline(
|
331 |
+
"automatic-speech-recognition",
|
332 |
+
model="openai/whisper-large-v3-turbo",
|
333 |
+
torch_dtype=torch.float16,
|
334 |
+
device=device,
|
335 |
+
)
|
336 |
|
337 |
text_transcribe = pipe(
|
338 |
file_audio,
|
339 |
chunk_length_s=30,
|
340 |
batch_size=128,
|
341 |
+
generate_kwargs={"task": "transcribe", "language": language},
|
342 |
return_timestamps=False,
|
343 |
)["text"].strip()
|
344 |
return text_transcribe
|
345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
|
347 |
+
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
|
348 |
+
name_project += "_pinyin"
|
349 |
+
path_project = os.path.join(path_data, name_project)
|
350 |
+
path_dataset = os.path.join(path_project, "dataset")
|
351 |
+
path_project_wavs = os.path.join(path_project, "wavs")
|
352 |
+
file_metadata = os.path.join(path_project, "metadata.csv")
|
353 |
+
|
354 |
+
if audio_files is None:
|
355 |
+
return "You need to load an audio file."
|
356 |
|
357 |
if os.path.isdir(path_project_wavs):
|
358 |
+
shutil.rmtree(path_project_wavs)
|
359 |
|
360 |
if os.path.isfile(file_metadata):
|
361 |
+
os.remove(file_metadata)
|
362 |
+
|
363 |
+
os.makedirs(path_project_wavs, exist_ok=True)
|
364 |
|
|
|
|
|
365 |
if user:
|
366 |
+
file_audios = [
|
367 |
+
file
|
368 |
+
for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
|
369 |
+
for file in glob(os.path.join(path_dataset, format))
|
370 |
+
]
|
371 |
+
if file_audios == []:
|
372 |
+
return "No audio file was found in the dataset."
|
373 |
else:
|
374 |
+
file_audios = audio_files
|
|
|
375 |
|
376 |
alpha = 0.5
|
377 |
_max = 1.0
|
|
|
379 |
|
380 |
num = 0
|
381 |
error_num = 0
|
382 |
+
data = ""
|
383 |
+
for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
|
384 |
+
audio, _ = librosa.load(file_audio, sr=24000, mono=True)
|
385 |
+
|
386 |
+
list_slicer = slicer.slice(audio)
|
387 |
+
for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
|
|
|
|
|
388 |
name_segment = os.path.join(f"segment_{num}")
|
389 |
+
file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
|
390 |
+
|
391 |
tmp_max = np.abs(chunk).max()
|
392 |
+
if tmp_max > 1:
|
393 |
+
chunk /= tmp_max
|
394 |
chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
|
395 |
+
wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
|
396 |
+
|
397 |
try:
|
398 |
+
text = transcribe(file_segment, language)
|
399 |
+
text = text.lower().strip().replace('"', "")
|
400 |
|
401 |
+
data += f"{name_segment}|{text}\n"
|
402 |
|
403 |
+
num += 1
|
404 |
+
except: # noqa: E722
|
405 |
+
error_num += 1
|
406 |
|
407 |
+
with open(file_metadata, "w", encoding="utf-8") as f:
|
408 |
f.write(data)
|
409 |
+
|
410 |
+
if error_num != []:
|
411 |
+
error_text = f"\nerror files : {error_num}"
|
412 |
else:
|
413 |
+
error_text = ""
|
414 |
+
|
415 |
return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
|
416 |
|
417 |
+
|
418 |
def format_seconds_to_hms(seconds):
|
419 |
hours = int(seconds / 3600)
|
420 |
minutes = int((seconds % 3600) / 60)
|
421 |
seconds = seconds % 60
|
422 |
return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
|
423 |
|
424 |
+
|
425 |
+
def create_metadata(name_project, progress=gr.Progress()):
|
426 |
+
name_project += "_pinyin"
|
427 |
+
path_project = os.path.join(path_data, name_project)
|
428 |
+
path_project_wavs = os.path.join(path_project, "wavs")
|
429 |
+
file_metadata = os.path.join(path_project, "metadata.csv")
|
430 |
+
file_raw = os.path.join(path_project, "raw.arrow")
|
431 |
+
file_duration = os.path.join(path_project, "duration.json")
|
432 |
+
file_vocab = os.path.join(path_project, "vocab.txt")
|
433 |
+
|
434 |
+
if not os.path.isfile(file_metadata):
|
435 |
+
return "The file was not found in " + file_metadata
|
436 |
+
|
437 |
+
with open(file_metadata, "r", encoding="utf-8") as f:
|
438 |
+
data = f.read()
|
439 |
+
|
440 |
+
audio_path_list = []
|
441 |
+
text_list = []
|
442 |
+
duration_list = []
|
443 |
+
|
444 |
+
count = data.split("\n")
|
445 |
+
lenght = 0
|
446 |
+
result = []
|
447 |
+
error_files = []
|
448 |
+
for line in progress.tqdm(data.split("\n"), total=count):
|
449 |
+
sp_line = line.split("|")
|
450 |
+
if len(sp_line) != 2:
|
451 |
+
continue
|
452 |
+
name_audio, text = sp_line[:2]
|
453 |
|
454 |
file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
|
455 |
|
456 |
+
if not os.path.isfile(file_audio):
|
457 |
error_files.append(file_audio)
|
458 |
continue
|
459 |
|
460 |
duraction = get_audio_duration(file_audio)
|
461 |
+
if duraction < 2 and duraction > 15:
|
462 |
+
continue
|
463 |
+
if len(text) < 4:
|
464 |
+
continue
|
465 |
|
466 |
text = clear_text(text)
|
467 |
+
text = convert_char_to_pinyin([text], polyphone=True)[0]
|
468 |
|
469 |
audio_path_list.append(file_audio)
|
470 |
duration_list.append(duraction)
|
471 |
text_list.append(text)
|
472 |
+
|
473 |
result.append({"audio_path": file_audio, "text": text, "duration": duraction})
|
474 |
|
475 |
+
lenght += duraction
|
476 |
|
477 |
+
if duration_list == []:
|
478 |
+
error_files_text = "\n".join(error_files)
|
479 |
return f"Error: No audio files found in the specified path : \n{error_files_text}"
|
480 |
+
|
481 |
+
min_second = round(min(duration_list), 2)
|
482 |
+
max_second = round(max(duration_list), 2)
|
483 |
|
484 |
with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
|
485 |
+
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
486 |
writer.write(line)
|
487 |
|
488 |
+
with open(file_duration, "w", encoding="utf-8") as f:
|
489 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
490 |
+
|
491 |
+
file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
492 |
+
if not os.path.isfile(file_vocab_finetune):
|
493 |
+
return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
|
494 |
shutil.copy2(file_vocab_finetune, file_vocab)
|
495 |
+
|
496 |
+
if error_files != []:
|
497 |
+
error_text = "error files\n" + "\n".join(error_files)
|
498 |
else:
|
499 |
+
error_text = ""
|
500 |
+
|
501 |
return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
|
502 |
|
503 |
+
|
504 |
def check_user(value):
|
505 |
+
return gr.update(visible=not value), gr.update(visible=value)
|
506 |
+
|
507 |
+
|
508 |
+
def calculate_train(
|
509 |
+
name_project,
|
510 |
+
batch_size_type,
|
511 |
+
max_samples,
|
512 |
+
learning_rate,
|
513 |
+
num_warmup_updates,
|
514 |
+
save_per_updates,
|
515 |
+
last_per_steps,
|
516 |
+
finetune,
|
517 |
+
):
|
518 |
+
name_project += "_pinyin"
|
519 |
+
path_project = os.path.join(path_data, name_project)
|
520 |
+
file_duraction = os.path.join(path_project, "duration.json")
|
521 |
|
522 |
+
with open(file_duraction, "r") as file:
|
523 |
+
data = json.load(file)
|
|
|
|
|
524 |
|
525 |
+
duration_list = data["duration"]
|
|
|
|
|
|
|
526 |
|
527 |
samples = len(duration_list)
|
528 |
|
529 |
if torch.cuda.is_available():
|
530 |
gpu_properties = torch.cuda.get_device_properties(0)
|
531 |
+
total_memory = gpu_properties.total_memory / (1024**3)
|
532 |
elif torch.backends.mps.is_available():
|
533 |
+
total_memory = psutil.virtual_memory().available / (1024**3)
|
534 |
+
|
535 |
+
if batch_size_type == "frame":
|
536 |
+
batch = int(total_memory * 0.5)
|
537 |
+
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
|
538 |
+
batch_size_per_gpu = int(38400 / batch)
|
539 |
else:
|
540 |
+
batch_size_per_gpu = int(total_memory / 8)
|
541 |
+
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
|
542 |
+
batch = batch_size_per_gpu
|
543 |
|
544 |
+
if batch_size_per_gpu <= 0:
|
545 |
+
batch_size_per_gpu = 1
|
546 |
|
547 |
+
if samples < 64:
|
548 |
+
max_samples = int(samples * 0.25)
|
549 |
else:
|
550 |
+
max_samples = 64
|
551 |
+
|
552 |
+
num_warmup_updates = int(samples * 0.10)
|
553 |
+
save_per_updates = int(samples * 0.25)
|
554 |
+
last_per_steps = int(save_per_updates * 5)
|
555 |
+
|
556 |
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
557 |
num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
|
558 |
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
|
559 |
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
560 |
|
561 |
+
if finetune:
|
562 |
+
learning_rate = 1e-4
|
563 |
+
else:
|
564 |
+
learning_rate = 7.5e-5
|
565 |
+
|
566 |
+
return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
|
567 |
|
|
|
568 |
|
569 |
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
|
570 |
try:
|
571 |
checkpoint = torch.load(checkpoint_path)
|
572 |
print("Original Checkpoint Keys:", checkpoint.keys())
|
573 |
+
|
574 |
+
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
|
575 |
|
576 |
if ema_model_state_dict is not None:
|
577 |
+
new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
|
578 |
torch.save(new_checkpoint, new_checkpoint_path)
|
579 |
return f"New checkpoint saved at: {new_checkpoint_path}"
|
580 |
else:
|
|
|
583 |
except Exception as e:
|
584 |
return f"An error occurred: {e}"
|
585 |
|
586 |
+
|
587 |
def vocab_check(project_name):
|
588 |
name_project = project_name + "_pinyin"
|
589 |
path_project = os.path.join(path_data, name_project)
|
590 |
|
591 |
file_metadata = os.path.join(path_project, "metadata.csv")
|
592 |
+
|
593 |
+
file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
594 |
+
if not os.path.isfile(file_vocab):
|
595 |
return f"the file {file_vocab} not found !"
|
596 |
+
|
597 |
+
with open(file_vocab, "r", encoding="utf-8") as f:
|
598 |
+
data = f.read()
|
599 |
|
600 |
vocab = data.split("\n")
|
601 |
|
602 |
+
if not os.path.isfile(file_metadata):
|
603 |
return f"the file {file_metadata} not found !"
|
604 |
|
605 |
+
with open(file_metadata, "r", encoding="utf-8") as f:
|
606 |
+
data = f.read()
|
607 |
|
608 |
+
miss_symbols = []
|
609 |
+
miss_symbols_keep = {}
|
610 |
for item in data.split("\n"):
|
611 |
+
sp = item.split("|")
|
612 |
+
if len(sp) != 2:
|
613 |
+
continue
|
614 |
+
text = sp[1].lower().strip()
|
615 |
+
|
616 |
+
for t in text:
|
617 |
+
if t not in vocab and t not in miss_symbols_keep:
|
618 |
+
miss_symbols.append(t)
|
619 |
+
miss_symbols_keep[t] = t
|
620 |
+
if miss_symbols == []:
|
621 |
+
info = "You can train using your language !"
|
622 |
+
else:
|
623 |
+
info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
|
624 |
|
625 |
return info
|
626 |
|
627 |
|
|
|
628 |
with gr.Blocks() as app:
|
|
|
629 |
with gr.Row():
|
630 |
+
project_name = gr.Textbox(label="project name", value="my_speak")
|
631 |
+
bt_create = gr.Button("create new project")
|
|
|
|
|
|
|
|
|
|
|
632 |
|
633 |
+
bt_create.click(fn=create_data_project, inputs=[project_name])
|
634 |
|
635 |
+
with gr.Tabs():
|
636 |
+
with gr.TabItem("transcribe Data"):
|
637 |
+
ch_manual = gr.Checkbox(label="user", value=False)
|
638 |
|
639 |
+
mark_info_transcribe = gr.Markdown(
|
640 |
+
"""```plaintext
|
|
|
|
|
641 |
Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
|
642 |
|
643 |
my_speak/
|
|
|
646 |
├── audio1.wav
|
647 |
└── audio2.wav
|
648 |
...
|
649 |
+
```""",
|
650 |
+
visible=False,
|
651 |
+
)
|
652 |
+
|
653 |
+
audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
|
654 |
+
txt_lang = gr.Text(label="Language", value="english")
|
655 |
+
bt_transcribe = bt_create = gr.Button("transcribe")
|
656 |
+
txt_info_transcribe = gr.Text(label="info", value="")
|
657 |
+
bt_transcribe.click(
|
658 |
+
fn=transcribe_all,
|
659 |
+
inputs=[project_name, audio_speaker, txt_lang, ch_manual],
|
660 |
+
outputs=[txt_info_transcribe],
|
661 |
+
)
|
662 |
+
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
|
663 |
+
|
664 |
+
with gr.TabItem("prepare Data"):
|
665 |
+
gr.Markdown(
|
666 |
+
"""```plaintext
|
667 |
place all your wavs folder and your metadata.csv file in {your name project}
|
668 |
my_speak/
|
669 |
│
|
|
|
680 |
audio2|text1
|
681 |
...
|
682 |
|
683 |
+
```"""
|
684 |
+
)
|
685 |
+
|
686 |
+
bt_prepare = bt_create = gr.Button("prepare")
|
687 |
+
txt_info_prepare = gr.Text(label="info", value="")
|
688 |
+
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
|
689 |
+
|
690 |
+
with gr.TabItem("train Data"):
|
691 |
+
with gr.Row():
|
692 |
+
bt_calculate = bt_create = gr.Button("Auto Settings")
|
693 |
+
ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
|
694 |
+
lb_samples = gr.Label(label="samples")
|
695 |
+
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
|
696 |
+
|
697 |
+
with gr.Row():
|
698 |
+
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
699 |
+
learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
|
700 |
+
|
701 |
+
with gr.Row():
|
702 |
+
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
703 |
+
max_samples = gr.Number(label="Max Samples", value=16)
|
704 |
+
|
705 |
+
with gr.Row():
|
706 |
+
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
707 |
+
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
|
708 |
+
|
709 |
+
with gr.Row():
|
710 |
+
epochs = gr.Number(label="Epochs", value=10)
|
711 |
+
num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
|
712 |
+
|
713 |
+
with gr.Row():
|
714 |
+
save_per_updates = gr.Number(label="Save per Updates", value=10)
|
715 |
+
last_per_steps = gr.Number(label="Last per Steps", value=50)
|
716 |
+
|
717 |
+
with gr.Row():
|
718 |
+
start_button = gr.Button("Start Training")
|
719 |
+
stop_button = gr.Button("Stop Training", interactive=False)
|
720 |
+
|
721 |
+
txt_info_train = gr.Text(label="info", value="")
|
722 |
+
start_button.click(
|
723 |
+
fn=start_training,
|
724 |
+
inputs=[
|
725 |
+
project_name,
|
726 |
+
exp_name,
|
727 |
+
learning_rate,
|
728 |
+
batch_size_per_gpu,
|
729 |
+
batch_size_type,
|
730 |
+
max_samples,
|
731 |
+
grad_accumulation_steps,
|
732 |
+
max_grad_norm,
|
733 |
+
epochs,
|
734 |
+
num_warmup_updates,
|
735 |
+
save_per_updates,
|
736 |
+
last_per_steps,
|
737 |
+
ch_finetune,
|
738 |
+
],
|
739 |
+
outputs=[txt_info_train, start_button, stop_button],
|
740 |
+
)
|
741 |
+
stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
|
742 |
+
bt_calculate.click(
|
743 |
+
fn=calculate_train,
|
744 |
+
inputs=[
|
745 |
+
project_name,
|
746 |
+
batch_size_type,
|
747 |
+
max_samples,
|
748 |
+
learning_rate,
|
749 |
+
num_warmup_updates,
|
750 |
+
save_per_updates,
|
751 |
+
last_per_steps,
|
752 |
+
ch_finetune,
|
753 |
+
],
|
754 |
+
outputs=[
|
755 |
+
batch_size_per_gpu,
|
756 |
+
max_samples,
|
757 |
+
num_warmup_updates,
|
758 |
+
save_per_updates,
|
759 |
+
last_per_steps,
|
760 |
+
lb_samples,
|
761 |
+
learning_rate,
|
762 |
+
],
|
763 |
+
)
|
764 |
+
|
765 |
+
with gr.TabItem("reduse checkpoint"):
|
766 |
+
txt_path_checkpoint = gr.Text(label="path checkpoint :")
|
767 |
+
txt_path_checkpoint_small = gr.Text(label="path output :")
|
768 |
+
txt_info_reduse = gr.Text(label="info", value="")
|
769 |
+
reduse_button = gr.Button("reduse")
|
770 |
+
reduse_button.click(
|
771 |
+
fn=extract_and_save_ema_model,
|
772 |
+
inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
|
773 |
+
outputs=[txt_info_reduse],
|
774 |
+
)
|
775 |
+
|
776 |
+
with gr.TabItem("vocab check experiment"):
|
777 |
+
check_button = gr.Button("check vocab")
|
778 |
+
txt_info_check = gr.Text(label="info", value="")
|
779 |
+
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
|
780 |
+
|
781 |
|
782 |
@click.command()
|
783 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
|
|
792 |
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
|
793 |
def main(port, host, share, api):
|
794 |
global app
|
795 |
+
print("Starting app...")
|
796 |
+
app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
|
797 |
+
|
|
|
798 |
|
799 |
if __name__ == "__main__":
|
800 |
main()
|
inference-cli.py
CHANGED
@@ -44,19 +44,8 @@ parser.add_argument(
|
|
44 |
"--vocab_file",
|
45 |
help="The vocab .txt",
|
46 |
)
|
47 |
-
parser.add_argument(
|
48 |
-
|
49 |
-
"--ref_audio",
|
50 |
-
type=str,
|
51 |
-
help="Reference audio file < 15 seconds."
|
52 |
-
)
|
53 |
-
parser.add_argument(
|
54 |
-
"-s",
|
55 |
-
"--ref_text",
|
56 |
-
type=str,
|
57 |
-
default="666",
|
58 |
-
help="Subtitle for the reference audio."
|
59 |
-
)
|
60 |
parser.add_argument(
|
61 |
"-t",
|
62 |
"--gen_text",
|
@@ -99,8 +88,8 @@ model = args.model if args.model else config["model"]
|
|
99 |
ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
100 |
vocab_file = args.vocab_file if args.vocab_file else ""
|
101 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
102 |
-
wave_path = Path(output_dir)/"out.wav"
|
103 |
-
spectrogram_path = Path(output_dir)/"out.png"
|
104 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
105 |
|
106 |
vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
|
@@ -110,44 +99,46 @@ vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_loc
|
|
110 |
if model == "F5-TTS":
|
111 |
model_cls = DiT
|
112 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
113 |
-
if ckpt_file == "":
|
114 |
-
repo_name= "F5-TTS"
|
115 |
exp_name = "F5TTS_Base"
|
116 |
-
ckpt_step= 1200000
|
117 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
118 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
119 |
|
120 |
elif model == "E2-TTS":
|
121 |
model_cls = UNetT
|
122 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
123 |
-
if ckpt_file == "":
|
124 |
-
repo_name= "E2-TTS"
|
125 |
exp_name = "E2TTS_Base"
|
126 |
-
ckpt_step= 1200000
|
127 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
128 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
129 |
|
130 |
print(f"Using {model}...")
|
131 |
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
132 |
-
|
133 |
|
134 |
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
135 |
-
main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
|
136 |
if "voices" not in config:
|
137 |
voices = {"main": main_voice}
|
138 |
else:
|
139 |
voices = config["voices"]
|
140 |
voices["main"] = main_voice
|
141 |
for voice in voices:
|
142 |
-
voices[voice][
|
|
|
|
|
143 |
print("Voice:", voice)
|
144 |
-
print("Ref_audio:", voices[voice][
|
145 |
-
print("Ref_text:", voices[voice][
|
146 |
|
147 |
generated_audio_segments = []
|
148 |
-
reg1 = r
|
149 |
chunks = re.split(reg1, text_gen)
|
150 |
-
reg2 = r
|
151 |
for text in chunks:
|
152 |
match = re.match(reg2, text)
|
153 |
if match:
|
@@ -160,8 +151,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
|
160 |
voice = "main"
|
161 |
text = re.sub(reg2, "", text)
|
162 |
gen_text = text.strip()
|
163 |
-
ref_audio = voices[voice][
|
164 |
-
ref_text = voices[voice][
|
165 |
print(f"Voice: {voice}")
|
166 |
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
|
167 |
generated_audio_segments.append(audio)
|
|
|
44 |
"--vocab_file",
|
45 |
help="The vocab .txt",
|
46 |
)
|
47 |
+
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
|
48 |
+
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
parser.add_argument(
|
50 |
"-t",
|
51 |
"--gen_text",
|
|
|
88 |
ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
89 |
vocab_file = args.vocab_file if args.vocab_file else ""
|
90 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
91 |
+
wave_path = Path(output_dir) / "out.wav"
|
92 |
+
spectrogram_path = Path(output_dir) / "out.png"
|
93 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
94 |
|
95 |
vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
|
|
|
99 |
if model == "F5-TTS":
|
100 |
model_cls = DiT
|
101 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
102 |
+
if ckpt_file == "":
|
103 |
+
repo_name = "F5-TTS"
|
104 |
exp_name = "F5TTS_Base"
|
105 |
+
ckpt_step = 1200000
|
106 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
107 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
108 |
|
109 |
elif model == "E2-TTS":
|
110 |
model_cls = UNetT
|
111 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
112 |
+
if ckpt_file == "":
|
113 |
+
repo_name = "E2-TTS"
|
114 |
exp_name = "E2TTS_Base"
|
115 |
+
ckpt_step = 1200000
|
116 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
117 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
118 |
|
119 |
print(f"Using {model}...")
|
120 |
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
121 |
+
|
122 |
|
123 |
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
124 |
+
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
|
125 |
if "voices" not in config:
|
126 |
voices = {"main": main_voice}
|
127 |
else:
|
128 |
voices = config["voices"]
|
129 |
voices["main"] = main_voice
|
130 |
for voice in voices:
|
131 |
+
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
|
132 |
+
voices[voice]["ref_audio"], voices[voice]["ref_text"]
|
133 |
+
)
|
134 |
print("Voice:", voice)
|
135 |
+
print("Ref_audio:", voices[voice]["ref_audio"])
|
136 |
+
print("Ref_text:", voices[voice]["ref_text"])
|
137 |
|
138 |
generated_audio_segments = []
|
139 |
+
reg1 = r"(?=\[\w+\])"
|
140 |
chunks = re.split(reg1, text_gen)
|
141 |
+
reg2 = r"\[(\w+)\]"
|
142 |
for text in chunks:
|
143 |
match = re.match(reg2, text)
|
144 |
if match:
|
|
|
151 |
voice = "main"
|
152 |
text = re.sub(reg2, "", text)
|
153 |
gen_text = text.strip()
|
154 |
+
ref_audio = voices[voice]["ref_audio"]
|
155 |
+
ref_text = voices[voice]["ref_text"]
|
156 |
print(f"Voice: {voice}")
|
157 |
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
|
158 |
generated_audio_segments.append(audio)
|
model/__init__.py
CHANGED
@@ -5,3 +5,6 @@ from model.backbones.dit import DiT
|
|
5 |
from model.backbones.mmdit import MMDiT
|
6 |
|
7 |
from model.trainer import Trainer
|
|
|
|
|
|
|
|
5 |
from model.backbones.mmdit import MMDiT
|
6 |
|
7 |
from model.trainer import Trainer
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
|
model/backbones/dit.py
CHANGED
@@ -21,14 +21,16 @@ from model.modules import (
|
|
21 |
ConvPositionEmbedding,
|
22 |
DiTBlock,
|
23 |
AdaLayerNormZero_Final,
|
24 |
-
precompute_freqs_cis,
|
|
|
25 |
)
|
26 |
|
27 |
|
28 |
# Text embedding
|
29 |
|
|
|
30 |
class TextEmbedding(nn.Module):
|
31 |
-
def __init__(self, text_num_embeds, text_dim, conv_layers
|
32 |
super().__init__()
|
33 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
34 |
|
@@ -36,20 +38,22 @@ class TextEmbedding(nn.Module):
|
|
36 |
self.extra_modeling = True
|
37 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
38 |
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
39 |
-
self.text_blocks = nn.Sequential(
|
|
|
|
|
40 |
else:
|
41 |
self.extra_modeling = False
|
42 |
|
43 |
-
def forward(self, text: int[
|
44 |
batch, text_len = text.shape[0], text.shape[1]
|
45 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
46 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
47 |
-
text = F.pad(text, (0, seq_len - text_len), value
|
48 |
|
49 |
if drop_text: # cfg for text
|
50 |
text = torch.zeros_like(text)
|
51 |
|
52 |
-
text = self.text_embed(text)
|
53 |
|
54 |
# possible extra modeling
|
55 |
if self.extra_modeling:
|
@@ -67,88 +71,91 @@ class TextEmbedding(nn.Module):
|
|
67 |
|
68 |
# noised input audio and context mixing embedding
|
69 |
|
|
|
70 |
class InputEmbedding(nn.Module):
|
71 |
def __init__(self, mel_dim, text_dim, out_dim):
|
72 |
super().__init__()
|
73 |
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
74 |
-
self.conv_pos_embed = ConvPositionEmbedding(dim
|
75 |
|
76 |
-
def forward(self, x: float[
|
77 |
if drop_audio_cond: # cfg for cond audio
|
78 |
cond = torch.zeros_like(cond)
|
79 |
|
80 |
-
x = self.proj(torch.cat((x, cond, text_embed), dim
|
81 |
x = self.conv_pos_embed(x) + x
|
82 |
return x
|
83 |
-
|
84 |
|
85 |
# Transformer backbone using DiT blocks
|
86 |
|
|
|
87 |
class DiT(nn.Module):
|
88 |
-
def __init__(
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
):
|
93 |
super().__init__()
|
94 |
|
95 |
self.time_embed = TimestepEmbedding(dim)
|
96 |
if text_dim is None:
|
97 |
text_dim = mel_dim
|
98 |
-
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers
|
99 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
100 |
|
101 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
102 |
|
103 |
self.dim = dim
|
104 |
self.depth = depth
|
105 |
-
|
106 |
self.transformer_blocks = nn.ModuleList(
|
107 |
-
[
|
108 |
-
DiTBlock(
|
109 |
-
dim = dim,
|
110 |
-
heads = heads,
|
111 |
-
dim_head = dim_head,
|
112 |
-
ff_mult = ff_mult,
|
113 |
-
dropout = dropout
|
114 |
-
)
|
115 |
-
for _ in range(depth)
|
116 |
-
]
|
117 |
)
|
118 |
-
self.long_skip_connection = nn.Linear(dim * 2, dim, bias
|
119 |
-
|
120 |
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
121 |
self.proj_out = nn.Linear(dim, mel_dim)
|
122 |
|
123 |
def forward(
|
124 |
self,
|
125 |
-
x: float[
|
126 |
-
cond: float[
|
127 |
-
text: int[
|
128 |
-
time: float[
|
129 |
drop_audio_cond, # cfg for cond audio
|
130 |
-
drop_text,
|
131 |
-
mask: bool[
|
132 |
):
|
133 |
batch, seq_len = x.shape[0], x.shape[1]
|
134 |
if time.ndim == 0:
|
135 |
time = time.repeat(batch)
|
136 |
-
|
137 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
138 |
t = self.time_embed(time)
|
139 |
-
text_embed = self.text_embed(text, seq_len, drop_text
|
140 |
-
x = self.input_embed(x, cond, text_embed, drop_audio_cond
|
141 |
-
|
142 |
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
143 |
|
144 |
if self.long_skip_connection is not None:
|
145 |
residual = x
|
146 |
|
147 |
for block in self.transformer_blocks:
|
148 |
-
x = block(x, t, mask
|
149 |
|
150 |
if self.long_skip_connection is not None:
|
151 |
-
x = self.long_skip_connection(torch.cat((x, residual), dim
|
152 |
|
153 |
x = self.norm_out(x, t)
|
154 |
output = self.proj_out(x)
|
|
|
21 |
ConvPositionEmbedding,
|
22 |
DiTBlock,
|
23 |
AdaLayerNormZero_Final,
|
24 |
+
precompute_freqs_cis,
|
25 |
+
get_pos_embed_indices,
|
26 |
)
|
27 |
|
28 |
|
29 |
# Text embedding
|
30 |
|
31 |
+
|
32 |
class TextEmbedding(nn.Module):
|
33 |
+
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
34 |
super().__init__()
|
35 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
36 |
|
|
|
38 |
self.extra_modeling = True
|
39 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
40 |
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
41 |
+
self.text_blocks = nn.Sequential(
|
42 |
+
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
43 |
+
)
|
44 |
else:
|
45 |
self.extra_modeling = False
|
46 |
|
47 |
+
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
48 |
batch, text_len = text.shape[0], text.shape[1]
|
49 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
50 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
51 |
+
text = F.pad(text, (0, seq_len - text_len), value=0)
|
52 |
|
53 |
if drop_text: # cfg for text
|
54 |
text = torch.zeros_like(text)
|
55 |
|
56 |
+
text = self.text_embed(text) # b n -> b n d
|
57 |
|
58 |
# possible extra modeling
|
59 |
if self.extra_modeling:
|
|
|
71 |
|
72 |
# noised input audio and context mixing embedding
|
73 |
|
74 |
+
|
75 |
class InputEmbedding(nn.Module):
|
76 |
def __init__(self, mel_dim, text_dim, out_dim):
|
77 |
super().__init__()
|
78 |
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
79 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
80 |
|
81 |
+
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
82 |
if drop_audio_cond: # cfg for cond audio
|
83 |
cond = torch.zeros_like(cond)
|
84 |
|
85 |
+
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
86 |
x = self.conv_pos_embed(x) + x
|
87 |
return x
|
88 |
+
|
89 |
|
90 |
# Transformer backbone using DiT blocks
|
91 |
|
92 |
+
|
93 |
class DiT(nn.Module):
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
*,
|
97 |
+
dim,
|
98 |
+
depth=8,
|
99 |
+
heads=8,
|
100 |
+
dim_head=64,
|
101 |
+
dropout=0.1,
|
102 |
+
ff_mult=4,
|
103 |
+
mel_dim=100,
|
104 |
+
text_num_embeds=256,
|
105 |
+
text_dim=None,
|
106 |
+
conv_layers=0,
|
107 |
+
long_skip_connection=False,
|
108 |
):
|
109 |
super().__init__()
|
110 |
|
111 |
self.time_embed = TimestepEmbedding(dim)
|
112 |
if text_dim is None:
|
113 |
text_dim = mel_dim
|
114 |
+
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
115 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
116 |
|
117 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
118 |
|
119 |
self.dim = dim
|
120 |
self.depth = depth
|
121 |
+
|
122 |
self.transformer_blocks = nn.ModuleList(
|
123 |
+
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
)
|
125 |
+
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
126 |
+
|
127 |
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
128 |
self.proj_out = nn.Linear(dim, mel_dim)
|
129 |
|
130 |
def forward(
|
131 |
self,
|
132 |
+
x: float["b n d"], # nosied input audio # noqa: F722
|
133 |
+
cond: float["b n d"], # masked cond audio # noqa: F722
|
134 |
+
text: int["b nt"], # text # noqa: F722
|
135 |
+
time: float["b"] | float[""], # time step # noqa: F821 F722
|
136 |
drop_audio_cond, # cfg for cond audio
|
137 |
+
drop_text, # cfg for text
|
138 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
139 |
):
|
140 |
batch, seq_len = x.shape[0], x.shape[1]
|
141 |
if time.ndim == 0:
|
142 |
time = time.repeat(batch)
|
143 |
+
|
144 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
145 |
t = self.time_embed(time)
|
146 |
+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
147 |
+
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
148 |
+
|
149 |
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
150 |
|
151 |
if self.long_skip_connection is not None:
|
152 |
residual = x
|
153 |
|
154 |
for block in self.transformer_blocks:
|
155 |
+
x = block(x, t, mask=mask, rope=rope)
|
156 |
|
157 |
if self.long_skip_connection is not None:
|
158 |
+
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
159 |
|
160 |
x = self.norm_out(x, t)
|
161 |
output = self.proj_out(x)
|
model/backbones/mmdit.py
CHANGED
@@ -19,12 +19,14 @@ from model.modules import (
|
|
19 |
ConvPositionEmbedding,
|
20 |
MMDiTBlock,
|
21 |
AdaLayerNormZero_Final,
|
22 |
-
precompute_freqs_cis,
|
|
|
23 |
)
|
24 |
|
25 |
|
26 |
# text embedding
|
27 |
|
|
|
28 |
class TextEmbedding(nn.Module):
|
29 |
def __init__(self, out_dim, text_num_embeds):
|
30 |
super().__init__()
|
@@ -33,7 +35,7 @@ class TextEmbedding(nn.Module):
|
|
33 |
self.precompute_max_pos = 1024
|
34 |
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
35 |
|
36 |
-
def forward(self, text: int[
|
37 |
text = text + 1
|
38 |
if drop_text:
|
39 |
text = torch.zeros_like(text)
|
@@ -52,27 +54,37 @@ class TextEmbedding(nn.Module):
|
|
52 |
|
53 |
# noised input & masked cond audio embedding
|
54 |
|
|
|
55 |
class AudioEmbedding(nn.Module):
|
56 |
def __init__(self, in_dim, out_dim):
|
57 |
super().__init__()
|
58 |
self.linear = nn.Linear(2 * in_dim, out_dim)
|
59 |
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
|
60 |
|
61 |
-
def forward(self, x: float[
|
62 |
if drop_audio_cond:
|
63 |
cond = torch.zeros_like(cond)
|
64 |
-
x = torch.cat((x, cond), dim
|
65 |
x = self.linear(x)
|
66 |
x = self.conv_pos_embed(x) + x
|
67 |
return x
|
68 |
-
|
69 |
|
70 |
# Transformer backbone using MM-DiT blocks
|
71 |
|
|
|
72 |
class MMDiT(nn.Module):
|
73 |
-
def __init__(
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
):
|
77 |
super().__init__()
|
78 |
|
@@ -84,16 +96,16 @@ class MMDiT(nn.Module):
|
|
84 |
|
85 |
self.dim = dim
|
86 |
self.depth = depth
|
87 |
-
|
88 |
self.transformer_blocks = nn.ModuleList(
|
89 |
[
|
90 |
MMDiTBlock(
|
91 |
-
dim
|
92 |
-
heads
|
93 |
-
dim_head
|
94 |
-
dropout
|
95 |
-
ff_mult
|
96 |
-
context_pre_only
|
97 |
)
|
98 |
for i in range(depth)
|
99 |
]
|
@@ -103,13 +115,13 @@ class MMDiT(nn.Module):
|
|
103 |
|
104 |
def forward(
|
105 |
self,
|
106 |
-
x: float[
|
107 |
-
cond: float[
|
108 |
-
text: int[
|
109 |
-
time: float[
|
110 |
drop_audio_cond, # cfg for cond audio
|
111 |
-
drop_text,
|
112 |
-
mask: bool[
|
113 |
):
|
114 |
batch = x.shape[0]
|
115 |
if time.ndim == 0:
|
@@ -117,16 +129,16 @@ class MMDiT(nn.Module):
|
|
117 |
|
118 |
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
119 |
t = self.time_embed(time)
|
120 |
-
c = self.text_embed(text, drop_text
|
121 |
-
x = self.audio_embed(x, cond, drop_audio_cond
|
122 |
|
123 |
seq_len = x.shape[1]
|
124 |
text_len = text.shape[1]
|
125 |
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
126 |
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
|
127 |
-
|
128 |
for block in self.transformer_blocks:
|
129 |
-
c, x = block(x, c, t, mask
|
130 |
|
131 |
x = self.norm_out(x, t)
|
132 |
output = self.proj_out(x)
|
|
|
19 |
ConvPositionEmbedding,
|
20 |
MMDiTBlock,
|
21 |
AdaLayerNormZero_Final,
|
22 |
+
precompute_freqs_cis,
|
23 |
+
get_pos_embed_indices,
|
24 |
)
|
25 |
|
26 |
|
27 |
# text embedding
|
28 |
|
29 |
+
|
30 |
class TextEmbedding(nn.Module):
|
31 |
def __init__(self, out_dim, text_num_embeds):
|
32 |
super().__init__()
|
|
|
35 |
self.precompute_max_pos = 1024
|
36 |
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
37 |
|
38 |
+
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
39 |
text = text + 1
|
40 |
if drop_text:
|
41 |
text = torch.zeros_like(text)
|
|
|
54 |
|
55 |
# noised input & masked cond audio embedding
|
56 |
|
57 |
+
|
58 |
class AudioEmbedding(nn.Module):
|
59 |
def __init__(self, in_dim, out_dim):
|
60 |
super().__init__()
|
61 |
self.linear = nn.Linear(2 * in_dim, out_dim)
|
62 |
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
|
63 |
|
64 |
+
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
|
65 |
if drop_audio_cond:
|
66 |
cond = torch.zeros_like(cond)
|
67 |
+
x = torch.cat((x, cond), dim=-1)
|
68 |
x = self.linear(x)
|
69 |
x = self.conv_pos_embed(x) + x
|
70 |
return x
|
71 |
+
|
72 |
|
73 |
# Transformer backbone using MM-DiT blocks
|
74 |
|
75 |
+
|
76 |
class MMDiT(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
*,
|
80 |
+
dim,
|
81 |
+
depth=8,
|
82 |
+
heads=8,
|
83 |
+
dim_head=64,
|
84 |
+
dropout=0.1,
|
85 |
+
ff_mult=4,
|
86 |
+
text_num_embeds=256,
|
87 |
+
mel_dim=100,
|
88 |
):
|
89 |
super().__init__()
|
90 |
|
|
|
96 |
|
97 |
self.dim = dim
|
98 |
self.depth = depth
|
99 |
+
|
100 |
self.transformer_blocks = nn.ModuleList(
|
101 |
[
|
102 |
MMDiTBlock(
|
103 |
+
dim=dim,
|
104 |
+
heads=heads,
|
105 |
+
dim_head=dim_head,
|
106 |
+
dropout=dropout,
|
107 |
+
ff_mult=ff_mult,
|
108 |
+
context_pre_only=i == depth - 1,
|
109 |
)
|
110 |
for i in range(depth)
|
111 |
]
|
|
|
115 |
|
116 |
def forward(
|
117 |
self,
|
118 |
+
x: float["b n d"], # nosied input audio # noqa: F722
|
119 |
+
cond: float["b n d"], # masked cond audio # noqa: F722
|
120 |
+
text: int["b nt"], # text # noqa: F722
|
121 |
+
time: float["b"] | float[""], # time step # noqa: F821 F722
|
122 |
drop_audio_cond, # cfg for cond audio
|
123 |
+
drop_text, # cfg for text
|
124 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
125 |
):
|
126 |
batch = x.shape[0]
|
127 |
if time.ndim == 0:
|
|
|
129 |
|
130 |
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
131 |
t = self.time_embed(time)
|
132 |
+
c = self.text_embed(text, drop_text=drop_text)
|
133 |
+
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
134 |
|
135 |
seq_len = x.shape[1]
|
136 |
text_len = text.shape[1]
|
137 |
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
138 |
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
|
139 |
+
|
140 |
for block in self.transformer_blocks:
|
141 |
+
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
|
142 |
|
143 |
x = self.norm_out(x, t)
|
144 |
output = self.proj_out(x)
|
model/backbones/unett.py
CHANGED
@@ -24,14 +24,16 @@ from model.modules import (
|
|
24 |
Attention,
|
25 |
AttnProcessor,
|
26 |
FeedForward,
|
27 |
-
precompute_freqs_cis,
|
|
|
28 |
)
|
29 |
|
30 |
|
31 |
# Text embedding
|
32 |
|
|
|
33 |
class TextEmbedding(nn.Module):
|
34 |
-
def __init__(self, text_num_embeds, text_dim, conv_layers
|
35 |
super().__init__()
|
36 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
37 |
|
@@ -39,20 +41,22 @@ class TextEmbedding(nn.Module):
|
|
39 |
self.extra_modeling = True
|
40 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
41 |
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
42 |
-
self.text_blocks = nn.Sequential(
|
|
|
|
|
43 |
else:
|
44 |
self.extra_modeling = False
|
45 |
|
46 |
-
def forward(self, text: int[
|
47 |
batch, text_len = text.shape[0], text.shape[1]
|
48 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
49 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
50 |
-
text = F.pad(text, (0, seq_len - text_len), value
|
51 |
|
52 |
if drop_text: # cfg for text
|
53 |
text = torch.zeros_like(text)
|
54 |
|
55 |
-
text = self.text_embed(text)
|
56 |
|
57 |
# possible extra modeling
|
58 |
if self.extra_modeling:
|
@@ -70,28 +74,40 @@ class TextEmbedding(nn.Module):
|
|
70 |
|
71 |
# noised input audio and context mixing embedding
|
72 |
|
|
|
73 |
class InputEmbedding(nn.Module):
|
74 |
def __init__(self, mel_dim, text_dim, out_dim):
|
75 |
super().__init__()
|
76 |
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
77 |
-
self.conv_pos_embed = ConvPositionEmbedding(dim
|
78 |
|
79 |
-
def forward(self, x: float[
|
80 |
if drop_audio_cond: # cfg for cond audio
|
81 |
cond = torch.zeros_like(cond)
|
82 |
|
83 |
-
x = self.proj(torch.cat((x, cond, text_embed), dim
|
84 |
x = self.conv_pos_embed(x) + x
|
85 |
return x
|
86 |
|
87 |
|
88 |
# Flat UNet Transformer backbone
|
89 |
|
|
|
90 |
class UNetT(nn.Module):
|
91 |
-
def __init__(
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
):
|
96 |
super().__init__()
|
97 |
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
|
@@ -99,7 +115,7 @@ class UNetT(nn.Module):
|
|
99 |
self.time_embed = TimestepEmbedding(dim)
|
100 |
if text_dim is None:
|
101 |
text_dim = mel_dim
|
102 |
-
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers
|
103 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
104 |
|
105 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
@@ -108,7 +124,7 @@ class UNetT(nn.Module):
|
|
108 |
|
109 |
self.dim = dim
|
110 |
self.skip_connect_type = skip_connect_type
|
111 |
-
needs_skip_proj = skip_connect_type ==
|
112 |
|
113 |
self.depth = depth
|
114 |
self.layers = nn.ModuleList([])
|
@@ -118,53 +134,57 @@ class UNetT(nn.Module):
|
|
118 |
|
119 |
attn_norm = RMSNorm(dim)
|
120 |
attn = Attention(
|
121 |
-
processor
|
122 |
-
dim
|
123 |
-
heads
|
124 |
-
dim_head
|
125 |
-
dropout
|
126 |
-
|
127 |
|
128 |
ff_norm = RMSNorm(dim)
|
129 |
-
ff = FeedForward(dim
|
130 |
-
|
131 |
-
skip_proj = nn.Linear(dim * 2, dim, bias
|
132 |
-
|
133 |
-
self.layers.append(
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
|
141 |
self.norm_out = RMSNorm(dim)
|
142 |
self.proj_out = nn.Linear(dim, mel_dim)
|
143 |
|
144 |
def forward(
|
145 |
self,
|
146 |
-
x: float[
|
147 |
-
cond: float[
|
148 |
-
text: int[
|
149 |
-
time: float[
|
150 |
drop_audio_cond, # cfg for cond audio
|
151 |
-
drop_text,
|
152 |
-
mask: bool[
|
153 |
):
|
154 |
batch, seq_len = x.shape[0], x.shape[1]
|
155 |
if time.ndim == 0:
|
156 |
time = time.repeat(batch)
|
157 |
-
|
158 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
159 |
t = self.time_embed(time)
|
160 |
-
text_embed = self.text_embed(text, seq_len, drop_text
|
161 |
-
x = self.input_embed(x, cond, text_embed, drop_audio_cond
|
162 |
|
163 |
# postfix time t to input x, [b n d] -> [b n+1 d]
|
164 |
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
|
165 |
if mask is not None:
|
166 |
mask = F.pad(mask, (1, 0), value=1)
|
167 |
-
|
168 |
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
|
169 |
|
170 |
# flat unet transformer
|
@@ -182,14 +202,14 @@ class UNetT(nn.Module):
|
|
182 |
|
183 |
if is_later_half:
|
184 |
skip = skips.pop()
|
185 |
-
if skip_connect_type ==
|
186 |
-
x = torch.cat((x, skip), dim
|
187 |
x = maybe_skip_proj(x)
|
188 |
-
elif skip_connect_type ==
|
189 |
x = x + skip
|
190 |
|
191 |
# attention and feedforward blocks
|
192 |
-
x = attn(attn_norm(x), rope
|
193 |
x = ff(ff_norm(x)) + x
|
194 |
|
195 |
assert len(skips) == 0
|
|
|
24 |
Attention,
|
25 |
AttnProcessor,
|
26 |
FeedForward,
|
27 |
+
precompute_freqs_cis,
|
28 |
+
get_pos_embed_indices,
|
29 |
)
|
30 |
|
31 |
|
32 |
# Text embedding
|
33 |
|
34 |
+
|
35 |
class TextEmbedding(nn.Module):
|
36 |
+
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
37 |
super().__init__()
|
38 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
39 |
|
|
|
41 |
self.extra_modeling = True
|
42 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
43 |
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
44 |
+
self.text_blocks = nn.Sequential(
|
45 |
+
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
46 |
+
)
|
47 |
else:
|
48 |
self.extra_modeling = False
|
49 |
|
50 |
+
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
51 |
batch, text_len = text.shape[0], text.shape[1]
|
52 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
53 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
54 |
+
text = F.pad(text, (0, seq_len - text_len), value=0)
|
55 |
|
56 |
if drop_text: # cfg for text
|
57 |
text = torch.zeros_like(text)
|
58 |
|
59 |
+
text = self.text_embed(text) # b n -> b n d
|
60 |
|
61 |
# possible extra modeling
|
62 |
if self.extra_modeling:
|
|
|
74 |
|
75 |
# noised input audio and context mixing embedding
|
76 |
|
77 |
+
|
78 |
class InputEmbedding(nn.Module):
|
79 |
def __init__(self, mel_dim, text_dim, out_dim):
|
80 |
super().__init__()
|
81 |
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
82 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
83 |
|
84 |
+
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
85 |
if drop_audio_cond: # cfg for cond audio
|
86 |
cond = torch.zeros_like(cond)
|
87 |
|
88 |
+
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
89 |
x = self.conv_pos_embed(x) + x
|
90 |
return x
|
91 |
|
92 |
|
93 |
# Flat UNet Transformer backbone
|
94 |
|
95 |
+
|
96 |
class UNetT(nn.Module):
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
*,
|
100 |
+
dim,
|
101 |
+
depth=8,
|
102 |
+
heads=8,
|
103 |
+
dim_head=64,
|
104 |
+
dropout=0.1,
|
105 |
+
ff_mult=4,
|
106 |
+
mel_dim=100,
|
107 |
+
text_num_embeds=256,
|
108 |
+
text_dim=None,
|
109 |
+
conv_layers=0,
|
110 |
+
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
111 |
):
|
112 |
super().__init__()
|
113 |
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
|
|
|
115 |
self.time_embed = TimestepEmbedding(dim)
|
116 |
if text_dim is None:
|
117 |
text_dim = mel_dim
|
118 |
+
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
119 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
120 |
|
121 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
|
|
124 |
|
125 |
self.dim = dim
|
126 |
self.skip_connect_type = skip_connect_type
|
127 |
+
needs_skip_proj = skip_connect_type == "concat"
|
128 |
|
129 |
self.depth = depth
|
130 |
self.layers = nn.ModuleList([])
|
|
|
134 |
|
135 |
attn_norm = RMSNorm(dim)
|
136 |
attn = Attention(
|
137 |
+
processor=AttnProcessor(),
|
138 |
+
dim=dim,
|
139 |
+
heads=heads,
|
140 |
+
dim_head=dim_head,
|
141 |
+
dropout=dropout,
|
142 |
+
)
|
143 |
|
144 |
ff_norm = RMSNorm(dim)
|
145 |
+
ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
146 |
+
|
147 |
+
skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
|
148 |
+
|
149 |
+
self.layers.append(
|
150 |
+
nn.ModuleList(
|
151 |
+
[
|
152 |
+
skip_proj,
|
153 |
+
attn_norm,
|
154 |
+
attn,
|
155 |
+
ff_norm,
|
156 |
+
ff,
|
157 |
+
]
|
158 |
+
)
|
159 |
+
)
|
160 |
|
161 |
self.norm_out = RMSNorm(dim)
|
162 |
self.proj_out = nn.Linear(dim, mel_dim)
|
163 |
|
164 |
def forward(
|
165 |
self,
|
166 |
+
x: float["b n d"], # nosied input audio # noqa: F722
|
167 |
+
cond: float["b n d"], # masked cond audio # noqa: F722
|
168 |
+
text: int["b nt"], # text # noqa: F722
|
169 |
+
time: float["b"] | float[""], # time step # noqa: F821 F722
|
170 |
drop_audio_cond, # cfg for cond audio
|
171 |
+
drop_text, # cfg for text
|
172 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
173 |
):
|
174 |
batch, seq_len = x.shape[0], x.shape[1]
|
175 |
if time.ndim == 0:
|
176 |
time = time.repeat(batch)
|
177 |
+
|
178 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
179 |
t = self.time_embed(time)
|
180 |
+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
181 |
+
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
182 |
|
183 |
# postfix time t to input x, [b n d] -> [b n+1 d]
|
184 |
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
|
185 |
if mask is not None:
|
186 |
mask = F.pad(mask, (1, 0), value=1)
|
187 |
+
|
188 |
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
|
189 |
|
190 |
# flat unet transformer
|
|
|
202 |
|
203 |
if is_later_half:
|
204 |
skip = skips.pop()
|
205 |
+
if skip_connect_type == "concat":
|
206 |
+
x = torch.cat((x, skip), dim=-1)
|
207 |
x = maybe_skip_proj(x)
|
208 |
+
elif skip_connect_type == "add":
|
209 |
x = x + skip
|
210 |
|
211 |
# attention and feedforward blocks
|
212 |
+
x = attn(attn_norm(x), rope=rope, mask=mask) + x
|
213 |
x = ff(ff_norm(x)) + x
|
214 |
|
215 |
assert len(skips) == 0
|
model/cfm.py
CHANGED
@@ -20,29 +20,32 @@ from torchdiffeq import odeint
|
|
20 |
|
21 |
from model.modules import MelSpec
|
22 |
from model.utils import (
|
23 |
-
default,
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
class CFM(nn.Module):
|
30 |
def __init__(
|
31 |
self,
|
32 |
transformer: nn.Module,
|
33 |
-
sigma
|
34 |
odeint_kwargs: dict = dict(
|
35 |
# atol = 1e-5,
|
36 |
# rtol = 1e-5,
|
37 |
-
method
|
38 |
),
|
39 |
-
audio_drop_prob
|
40 |
-
cond_drop_prob
|
41 |
-
num_channels
|
42 |
mel_spec_module: nn.Module | None = None,
|
43 |
mel_spec_kwargs: dict = dict(),
|
44 |
-
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
|
45 |
-
vocab_char_map: dict[str:
|
46 |
):
|
47 |
super().__init__()
|
48 |
|
@@ -78,21 +81,21 @@ class CFM(nn.Module):
|
|
78 |
@torch.no_grad()
|
79 |
def sample(
|
80 |
self,
|
81 |
-
cond: float[
|
82 |
-
text: int[
|
83 |
-
duration: int | int[
|
84 |
*,
|
85 |
-
lens: int[
|
86 |
-
steps
|
87 |
-
cfg_strength
|
88 |
-
sway_sampling_coef
|
89 |
seed: int | None = None,
|
90 |
-
max_duration
|
91 |
-
vocoder: Callable[[float[
|
92 |
-
no_ref_audio
|
93 |
-
duplicate_test
|
94 |
-
t_inter
|
95 |
-
edit_mask
|
96 |
):
|
97 |
self.eval()
|
98 |
|
@@ -108,7 +111,7 @@ class CFM(nn.Module):
|
|
108 |
|
109 |
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
110 |
if not exists(lens):
|
111 |
-
lens = torch.full((batch,), cond_seq_len, device
|
112 |
|
113 |
# text
|
114 |
|
@@ -120,8 +123,8 @@ class CFM(nn.Module):
|
|
120 |
assert text.shape[0] == batch
|
121 |
|
122 |
if exists(text):
|
123 |
-
text_lens = (text != -1).sum(dim
|
124 |
-
lens = torch.maximum(text_lens, lens)
|
125 |
|
126 |
# duration
|
127 |
|
@@ -130,20 +133,22 @@ class CFM(nn.Module):
|
|
130 |
cond_mask = cond_mask & edit_mask
|
131 |
|
132 |
if isinstance(duration, int):
|
133 |
-
duration = torch.full((batch,), duration, device
|
134 |
|
135 |
-
duration = torch.maximum(lens + 1, duration)
|
136 |
-
duration = duration.clamp(max
|
137 |
max_duration = duration.amax()
|
138 |
-
|
139 |
# duplicate test corner for inner time step oberservation
|
140 |
if duplicate_test:
|
141 |
-
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value
|
142 |
-
|
143 |
-
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value
|
144 |
-
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value
|
145 |
cond_mask = cond_mask.unsqueeze(-1)
|
146 |
-
step_cond = torch.where(
|
|
|
|
|
147 |
|
148 |
if batch > 1:
|
149 |
mask = lens_to_mask(duration)
|
@@ -161,11 +166,15 @@ class CFM(nn.Module):
|
|
161 |
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
162 |
|
163 |
# predict flow
|
164 |
-
pred = self.transformer(
|
|
|
|
|
165 |
if cfg_strength < 1e-5:
|
166 |
return pred
|
167 |
-
|
168 |
-
null_pred = self.transformer(
|
|
|
|
|
169 |
return pred + (pred - null_pred) * cfg_strength
|
170 |
|
171 |
# noise input
|
@@ -175,8 +184,8 @@ class CFM(nn.Module):
|
|
175 |
for dur in duration:
|
176 |
if exists(seed):
|
177 |
torch.manual_seed(seed)
|
178 |
-
y0.append(torch.randn(dur, self.num_channels, device
|
179 |
-
y0 = pad_sequence(y0, padding_value
|
180 |
|
181 |
t_start = 0
|
182 |
|
@@ -186,12 +195,12 @@ class CFM(nn.Module):
|
|
186 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
187 |
steps = int(steps * (1 - t_start))
|
188 |
|
189 |
-
t = torch.linspace(t_start, 1, steps, device
|
190 |
if sway_sampling_coef is not None:
|
191 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
192 |
|
193 |
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
194 |
-
|
195 |
sampled = trajectory[-1]
|
196 |
out = sampled
|
197 |
out = torch.where(cond_mask, cond, out)
|
@@ -204,10 +213,10 @@ class CFM(nn.Module):
|
|
204 |
|
205 |
def forward(
|
206 |
self,
|
207 |
-
inp: float[
|
208 |
-
text: int[
|
209 |
*,
|
210 |
-
lens: int[
|
211 |
noise_scheduler: str | None = None,
|
212 |
):
|
213 |
# handle raw wave
|
@@ -216,7 +225,7 @@ class CFM(nn.Module):
|
|
216 |
inp = inp.permute(0, 2, 1)
|
217 |
assert inp.shape[-1] == self.num_channels
|
218 |
|
219 |
-
batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
|
220 |
|
221 |
# handle text as string
|
222 |
if isinstance(text, list):
|
@@ -228,12 +237,12 @@ class CFM(nn.Module):
|
|
228 |
|
229 |
# lens and mask
|
230 |
if not exists(lens):
|
231 |
-
lens = torch.full((batch,), seq_len, device
|
232 |
-
|
233 |
-
mask = lens_to_mask(lens, length
|
234 |
|
235 |
# get a random span to mask out for training conditionally
|
236 |
-
frac_lengths = torch.zeros((batch,), device
|
237 |
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
|
238 |
|
239 |
if exists(mask):
|
@@ -246,7 +255,7 @@ class CFM(nn.Module):
|
|
246 |
x0 = torch.randn_like(x1)
|
247 |
|
248 |
# time step
|
249 |
-
time = torch.rand((batch,), dtype
|
250 |
# TODO. noise_scheduler
|
251 |
|
252 |
# sample xt (φ_t(x) in the paper)
|
@@ -255,10 +264,7 @@ class CFM(nn.Module):
|
|
255 |
flow = x1 - x0
|
256 |
|
257 |
# only predict what is within the random mask span for infilling
|
258 |
-
cond = torch.where(
|
259 |
-
rand_span_mask[..., None],
|
260 |
-
torch.zeros_like(x1), x1
|
261 |
-
)
|
262 |
|
263 |
# transformer and cfg training with a drop rate
|
264 |
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
|
@@ -267,13 +273,15 @@ class CFM(nn.Module):
|
|
267 |
drop_text = True
|
268 |
else:
|
269 |
drop_text = False
|
270 |
-
|
271 |
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
|
272 |
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
273 |
-
pred = self.transformer(
|
|
|
|
|
274 |
|
275 |
# flow matching loss
|
276 |
-
loss = F.mse_loss(pred, flow, reduction
|
277 |
loss = loss[rand_span_mask]
|
278 |
|
279 |
return loss.mean(), cond, pred
|
|
|
20 |
|
21 |
from model.modules import MelSpec
|
22 |
from model.utils import (
|
23 |
+
default,
|
24 |
+
exists,
|
25 |
+
list_str_to_idx,
|
26 |
+
list_str_to_tensor,
|
27 |
+
lens_to_mask,
|
28 |
+
mask_from_frac_lengths,
|
29 |
+
)
|
30 |
|
31 |
|
32 |
class CFM(nn.Module):
|
33 |
def __init__(
|
34 |
self,
|
35 |
transformer: nn.Module,
|
36 |
+
sigma=0.0,
|
37 |
odeint_kwargs: dict = dict(
|
38 |
# atol = 1e-5,
|
39 |
# rtol = 1e-5,
|
40 |
+
method="euler" # 'midpoint'
|
41 |
),
|
42 |
+
audio_drop_prob=0.3,
|
43 |
+
cond_drop_prob=0.2,
|
44 |
+
num_channels=None,
|
45 |
mel_spec_module: nn.Module | None = None,
|
46 |
mel_spec_kwargs: dict = dict(),
|
47 |
+
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
|
48 |
+
vocab_char_map: dict[str:int] | None = None,
|
49 |
):
|
50 |
super().__init__()
|
51 |
|
|
|
81 |
@torch.no_grad()
|
82 |
def sample(
|
83 |
self,
|
84 |
+
cond: float["b n d"] | float["b nw"], # noqa: F722
|
85 |
+
text: int["b nt"] | list[str], # noqa: F722
|
86 |
+
duration: int | int["b"], # noqa: F821
|
87 |
*,
|
88 |
+
lens: int["b"] | None = None, # noqa: F821
|
89 |
+
steps=32,
|
90 |
+
cfg_strength=1.0,
|
91 |
+
sway_sampling_coef=None,
|
92 |
seed: int | None = None,
|
93 |
+
max_duration=4096,
|
94 |
+
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
95 |
+
no_ref_audio=False,
|
96 |
+
duplicate_test=False,
|
97 |
+
t_inter=0.1,
|
98 |
+
edit_mask=None,
|
99 |
):
|
100 |
self.eval()
|
101 |
|
|
|
111 |
|
112 |
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
113 |
if not exists(lens):
|
114 |
+
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
|
115 |
|
116 |
# text
|
117 |
|
|
|
123 |
assert text.shape[0] == batch
|
124 |
|
125 |
if exists(text):
|
126 |
+
text_lens = (text != -1).sum(dim=-1)
|
127 |
+
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
|
128 |
|
129 |
# duration
|
130 |
|
|
|
133 |
cond_mask = cond_mask & edit_mask
|
134 |
|
135 |
if isinstance(duration, int):
|
136 |
+
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
137 |
|
138 |
+
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
|
139 |
+
duration = duration.clamp(max=max_duration)
|
140 |
max_duration = duration.amax()
|
141 |
+
|
142 |
# duplicate test corner for inner time step oberservation
|
143 |
if duplicate_test:
|
144 |
+
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
|
145 |
+
|
146 |
+
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
147 |
+
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
|
148 |
cond_mask = cond_mask.unsqueeze(-1)
|
149 |
+
step_cond = torch.where(
|
150 |
+
cond_mask, cond, torch.zeros_like(cond)
|
151 |
+
) # allow direct control (cut cond audio) with lens passed in
|
152 |
|
153 |
if batch > 1:
|
154 |
mask = lens_to_mask(duration)
|
|
|
166 |
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
167 |
|
168 |
# predict flow
|
169 |
+
pred = self.transformer(
|
170 |
+
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
|
171 |
+
)
|
172 |
if cfg_strength < 1e-5:
|
173 |
return pred
|
174 |
+
|
175 |
+
null_pred = self.transformer(
|
176 |
+
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
|
177 |
+
)
|
178 |
return pred + (pred - null_pred) * cfg_strength
|
179 |
|
180 |
# noise input
|
|
|
184 |
for dur in duration:
|
185 |
if exists(seed):
|
186 |
torch.manual_seed(seed)
|
187 |
+
y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
|
188 |
+
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
|
189 |
|
190 |
t_start = 0
|
191 |
|
|
|
195 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
196 |
steps = int(steps * (1 - t_start))
|
197 |
|
198 |
+
t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
|
199 |
if sway_sampling_coef is not None:
|
200 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
201 |
|
202 |
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
203 |
+
|
204 |
sampled = trajectory[-1]
|
205 |
out = sampled
|
206 |
out = torch.where(cond_mask, cond, out)
|
|
|
213 |
|
214 |
def forward(
|
215 |
self,
|
216 |
+
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
|
217 |
+
text: int["b nt"] | list[str], # noqa: F722
|
218 |
*,
|
219 |
+
lens: int["b"] | None = None, # noqa: F821
|
220 |
noise_scheduler: str | None = None,
|
221 |
):
|
222 |
# handle raw wave
|
|
|
225 |
inp = inp.permute(0, 2, 1)
|
226 |
assert inp.shape[-1] == self.num_channels
|
227 |
|
228 |
+
batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
|
229 |
|
230 |
# handle text as string
|
231 |
if isinstance(text, list):
|
|
|
237 |
|
238 |
# lens and mask
|
239 |
if not exists(lens):
|
240 |
+
lens = torch.full((batch,), seq_len, device=device)
|
241 |
+
|
242 |
+
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
|
243 |
|
244 |
# get a random span to mask out for training conditionally
|
245 |
+
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
|
246 |
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
|
247 |
|
248 |
if exists(mask):
|
|
|
255 |
x0 = torch.randn_like(x1)
|
256 |
|
257 |
# time step
|
258 |
+
time = torch.rand((batch,), dtype=dtype, device=self.device)
|
259 |
# TODO. noise_scheduler
|
260 |
|
261 |
# sample xt (φ_t(x) in the paper)
|
|
|
264 |
flow = x1 - x0
|
265 |
|
266 |
# only predict what is within the random mask span for infilling
|
267 |
+
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
|
|
|
|
|
|
|
268 |
|
269 |
# transformer and cfg training with a drop rate
|
270 |
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
|
|
|
273 |
drop_text = True
|
274 |
else:
|
275 |
drop_text = False
|
276 |
+
|
277 |
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
|
278 |
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
279 |
+
pred = self.transformer(
|
280 |
+
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
|
281 |
+
)
|
282 |
|
283 |
# flow matching loss
|
284 |
+
loss = F.mse_loss(pred, flow, reduction="none")
|
285 |
loss = loss[rand_span_mask]
|
286 |
|
287 |
return loss.mean(), cond, pred
|
model/dataset.py
CHANGED
@@ -6,7 +6,7 @@ import torch
|
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import Dataset, Sampler
|
8 |
import torchaudio
|
9 |
-
from datasets import
|
10 |
from datasets import Dataset as Dataset_
|
11 |
|
12 |
from model.modules import MelSpec
|
@@ -16,53 +16,55 @@ class HFDataset(Dataset):
|
|
16 |
def __init__(
|
17 |
self,
|
18 |
hf_dataset: Dataset,
|
19 |
-
target_sample_rate
|
20 |
-
n_mel_channels
|
21 |
-
hop_length
|
22 |
):
|
23 |
self.data = hf_dataset
|
24 |
self.target_sample_rate = target_sample_rate
|
25 |
self.hop_length = hop_length
|
26 |
-
self.mel_spectrogram = MelSpec(
|
27 |
-
|
|
|
|
|
28 |
def get_frame_len(self, index):
|
29 |
row = self.data[index]
|
30 |
-
audio = row[
|
31 |
-
sample_rate = row[
|
32 |
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
|
33 |
|
34 |
def __len__(self):
|
35 |
return len(self.data)
|
36 |
-
|
37 |
def __getitem__(self, index):
|
38 |
row = self.data[index]
|
39 |
-
audio = row[
|
40 |
|
41 |
# logger.info(f"Audio shape: {audio.shape}")
|
42 |
|
43 |
-
sample_rate = row[
|
44 |
duration = audio.shape[-1] / sample_rate
|
45 |
|
46 |
if duration > 30 or duration < 0.3:
|
47 |
return self.__getitem__((index + 1) % len(self.data))
|
48 |
-
|
49 |
audio_tensor = torch.from_numpy(audio).float()
|
50 |
-
|
51 |
if sample_rate != self.target_sample_rate:
|
52 |
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
53 |
audio_tensor = resampler(audio_tensor)
|
54 |
-
|
55 |
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
|
56 |
-
|
57 |
mel_spec = self.mel_spectrogram(audio_tensor)
|
58 |
-
|
59 |
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
|
60 |
-
|
61 |
-
text = row[
|
62 |
-
|
63 |
return dict(
|
64 |
-
mel_spec
|
65 |
-
text
|
66 |
)
|
67 |
|
68 |
|
@@ -70,11 +72,11 @@ class CustomDataset(Dataset):
|
|
70 |
def __init__(
|
71 |
self,
|
72 |
custom_dataset: Dataset,
|
73 |
-
durations
|
74 |
-
target_sample_rate
|
75 |
-
hop_length
|
76 |
-
n_mel_channels
|
77 |
-
preprocessed_mel
|
78 |
):
|
79 |
self.data = custom_dataset
|
80 |
self.durations = durations
|
@@ -82,16 +84,20 @@ class CustomDataset(Dataset):
|
|
82 |
self.hop_length = hop_length
|
83 |
self.preprocessed_mel = preprocessed_mel
|
84 |
if not preprocessed_mel:
|
85 |
-
self.mel_spectrogram = MelSpec(
|
|
|
|
|
86 |
|
87 |
def get_frame_len(self, index):
|
88 |
-
if
|
|
|
|
|
89 |
return self.durations[index] * self.target_sample_rate / self.hop_length
|
90 |
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
|
91 |
-
|
92 |
def __len__(self):
|
93 |
return len(self.data)
|
94 |
-
|
95 |
def __getitem__(self, index):
|
96 |
row = self.data[index]
|
97 |
audio_path = row["audio_path"]
|
@@ -108,45 +114,52 @@ class CustomDataset(Dataset):
|
|
108 |
|
109 |
if duration > 30 or duration < 0.3:
|
110 |
return self.__getitem__((index + 1) % len(self.data))
|
111 |
-
|
112 |
if source_sample_rate != self.target_sample_rate:
|
113 |
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
|
114 |
audio = resampler(audio)
|
115 |
-
|
116 |
mel_spec = self.mel_spectrogram(audio)
|
117 |
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
|
118 |
-
|
119 |
return dict(
|
120 |
-
mel_spec
|
121 |
-
text
|
122 |
)
|
123 |
-
|
124 |
|
125 |
# Dynamic Batch Sampler
|
126 |
|
|
|
127 |
class DynamicBatchSampler(Sampler[list[int]]):
|
128 |
-
"""
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
"""
|
134 |
|
135 |
-
def __init__(
|
|
|
|
|
136 |
self.sampler = sampler
|
137 |
self.frames_threshold = frames_threshold
|
138 |
self.max_samples = max_samples
|
139 |
|
140 |
indices, batches = [], []
|
141 |
data_source = self.sampler.data_source
|
142 |
-
|
143 |
-
for idx in tqdm(
|
|
|
|
|
144 |
indices.append((idx, data_source.get_frame_len(idx)))
|
145 |
-
indices.sort(key=lambda elem
|
146 |
|
147 |
batch = []
|
148 |
batch_frames = 0
|
149 |
-
for idx, frame_len in tqdm(
|
|
|
|
|
150 |
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
|
151 |
batch.append(idx)
|
152 |
batch_frames += frame_len
|
@@ -182,76 +195,86 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
|
182 |
|
183 |
# Load dataset
|
184 |
|
|
|
185 |
def load_dataset(
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
|
194 |
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
|
195 |
-
|
196 |
-
|
197 |
print("Loading dataset ...")
|
198 |
|
199 |
if dataset_type == "CustomDataset":
|
200 |
if audio_type == "raw":
|
201 |
try:
|
202 |
train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
|
203 |
-
except:
|
204 |
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
|
205 |
preprocessed_mel = False
|
206 |
elif audio_type == "mel":
|
207 |
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
|
208 |
preprocessed_mel = True
|
209 |
-
with open(f"data/{dataset_name}_{tokenizer}/duration.json",
|
210 |
data_dict = json.load(f)
|
211 |
durations = data_dict["duration"]
|
212 |
-
train_dataset = CustomDataset(
|
213 |
-
|
|
|
|
|
214 |
elif dataset_type == "CustomDatasetPath":
|
215 |
try:
|
216 |
train_dataset = load_from_disk(f"{dataset_name}/raw")
|
217 |
-
except:
|
218 |
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
|
219 |
-
|
220 |
-
with open(f"{dataset_name}/duration.json",
|
221 |
data_dict = json.load(f)
|
222 |
durations = data_dict["duration"]
|
223 |
-
train_dataset = CustomDataset(
|
224 |
-
|
|
|
|
|
225 |
elif dataset_type == "HFDataset":
|
226 |
-
print(
|
227 |
-
|
|
|
|
|
228 |
pre, post = dataset_name.split("_")
|
229 |
-
train_dataset = HFDataset(
|
|
|
|
|
230 |
|
231 |
return train_dataset
|
232 |
|
233 |
|
234 |
# collation
|
235 |
|
|
|
236 |
def collate_fn(batch):
|
237 |
-
mel_specs = [item[
|
238 |
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
|
239 |
max_mel_length = mel_lengths.amax()
|
240 |
|
241 |
padded_mel_specs = []
|
242 |
for spec in mel_specs: # TODO. maybe records mask for attention here
|
243 |
padding = (0, max_mel_length - spec.size(-1))
|
244 |
-
padded_spec = F.pad(spec, padding, value
|
245 |
padded_mel_specs.append(padded_spec)
|
246 |
-
|
247 |
mel_specs = torch.stack(padded_mel_specs)
|
248 |
|
249 |
-
text = [item[
|
250 |
text_lengths = torch.LongTensor([len(item) for item in text])
|
251 |
|
252 |
return dict(
|
253 |
-
mel
|
254 |
-
mel_lengths
|
255 |
-
text
|
256 |
-
text_lengths
|
257 |
)
|
|
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import Dataset, Sampler
|
8 |
import torchaudio
|
9 |
+
from datasets import load_from_disk
|
10 |
from datasets import Dataset as Dataset_
|
11 |
|
12 |
from model.modules import MelSpec
|
|
|
16 |
def __init__(
|
17 |
self,
|
18 |
hf_dataset: Dataset,
|
19 |
+
target_sample_rate=24_000,
|
20 |
+
n_mel_channels=100,
|
21 |
+
hop_length=256,
|
22 |
):
|
23 |
self.data = hf_dataset
|
24 |
self.target_sample_rate = target_sample_rate
|
25 |
self.hop_length = hop_length
|
26 |
+
self.mel_spectrogram = MelSpec(
|
27 |
+
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
|
28 |
+
)
|
29 |
+
|
30 |
def get_frame_len(self, index):
|
31 |
row = self.data[index]
|
32 |
+
audio = row["audio"]["array"]
|
33 |
+
sample_rate = row["audio"]["sampling_rate"]
|
34 |
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
|
35 |
|
36 |
def __len__(self):
|
37 |
return len(self.data)
|
38 |
+
|
39 |
def __getitem__(self, index):
|
40 |
row = self.data[index]
|
41 |
+
audio = row["audio"]["array"]
|
42 |
|
43 |
# logger.info(f"Audio shape: {audio.shape}")
|
44 |
|
45 |
+
sample_rate = row["audio"]["sampling_rate"]
|
46 |
duration = audio.shape[-1] / sample_rate
|
47 |
|
48 |
if duration > 30 or duration < 0.3:
|
49 |
return self.__getitem__((index + 1) % len(self.data))
|
50 |
+
|
51 |
audio_tensor = torch.from_numpy(audio).float()
|
52 |
+
|
53 |
if sample_rate != self.target_sample_rate:
|
54 |
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
55 |
audio_tensor = resampler(audio_tensor)
|
56 |
+
|
57 |
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
|
58 |
+
|
59 |
mel_spec = self.mel_spectrogram(audio_tensor)
|
60 |
+
|
61 |
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
|
62 |
+
|
63 |
+
text = row["text"]
|
64 |
+
|
65 |
return dict(
|
66 |
+
mel_spec=mel_spec,
|
67 |
+
text=text,
|
68 |
)
|
69 |
|
70 |
|
|
|
72 |
def __init__(
|
73 |
self,
|
74 |
custom_dataset: Dataset,
|
75 |
+
durations=None,
|
76 |
+
target_sample_rate=24_000,
|
77 |
+
hop_length=256,
|
78 |
+
n_mel_channels=100,
|
79 |
+
preprocessed_mel=False,
|
80 |
):
|
81 |
self.data = custom_dataset
|
82 |
self.durations = durations
|
|
|
84 |
self.hop_length = hop_length
|
85 |
self.preprocessed_mel = preprocessed_mel
|
86 |
if not preprocessed_mel:
|
87 |
+
self.mel_spectrogram = MelSpec(
|
88 |
+
target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
|
89 |
+
)
|
90 |
|
91 |
def get_frame_len(self, index):
|
92 |
+
if (
|
93 |
+
self.durations is not None
|
94 |
+
): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
|
95 |
return self.durations[index] * self.target_sample_rate / self.hop_length
|
96 |
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
|
97 |
+
|
98 |
def __len__(self):
|
99 |
return len(self.data)
|
100 |
+
|
101 |
def __getitem__(self, index):
|
102 |
row = self.data[index]
|
103 |
audio_path = row["audio_path"]
|
|
|
114 |
|
115 |
if duration > 30 or duration < 0.3:
|
116 |
return self.__getitem__((index + 1) % len(self.data))
|
117 |
+
|
118 |
if source_sample_rate != self.target_sample_rate:
|
119 |
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
|
120 |
audio = resampler(audio)
|
121 |
+
|
122 |
mel_spec = self.mel_spectrogram(audio)
|
123 |
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
|
124 |
+
|
125 |
return dict(
|
126 |
+
mel_spec=mel_spec,
|
127 |
+
text=text,
|
128 |
)
|
129 |
+
|
130 |
|
131 |
# Dynamic Batch Sampler
|
132 |
|
133 |
+
|
134 |
class DynamicBatchSampler(Sampler[list[int]]):
|
135 |
+
"""Extension of Sampler that will do the following:
|
136 |
+
1. Change the batch size (essentially number of sequences)
|
137 |
+
in a batch to ensure that the total number of frames are less
|
138 |
+
than a certain threshold.
|
139 |
+
2. Make sure the padding efficiency in the batch is high.
|
140 |
"""
|
141 |
|
142 |
+
def __init__(
|
143 |
+
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
|
144 |
+
):
|
145 |
self.sampler = sampler
|
146 |
self.frames_threshold = frames_threshold
|
147 |
self.max_samples = max_samples
|
148 |
|
149 |
indices, batches = [], []
|
150 |
data_source = self.sampler.data_source
|
151 |
+
|
152 |
+
for idx in tqdm(
|
153 |
+
self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
|
154 |
+
):
|
155 |
indices.append((idx, data_source.get_frame_len(idx)))
|
156 |
+
indices.sort(key=lambda elem: elem[1])
|
157 |
|
158 |
batch = []
|
159 |
batch_frames = 0
|
160 |
+
for idx, frame_len in tqdm(
|
161 |
+
indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
|
162 |
+
):
|
163 |
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
|
164 |
batch.append(idx)
|
165 |
batch_frames += frame_len
|
|
|
195 |
|
196 |
# Load dataset
|
197 |
|
198 |
+
|
199 |
def load_dataset(
|
200 |
+
dataset_name: str,
|
201 |
+
tokenizer: str = "pinyin",
|
202 |
+
dataset_type: str = "CustomDataset",
|
203 |
+
audio_type: str = "raw",
|
204 |
+
mel_spec_kwargs: dict = dict(),
|
205 |
+
) -> CustomDataset | HFDataset:
|
206 |
+
"""
|
207 |
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
|
208 |
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
|
209 |
+
"""
|
210 |
+
|
211 |
print("Loading dataset ...")
|
212 |
|
213 |
if dataset_type == "CustomDataset":
|
214 |
if audio_type == "raw":
|
215 |
try:
|
216 |
train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
|
217 |
+
except: # noqa: E722
|
218 |
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
|
219 |
preprocessed_mel = False
|
220 |
elif audio_type == "mel":
|
221 |
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
|
222 |
preprocessed_mel = True
|
223 |
+
with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
|
224 |
data_dict = json.load(f)
|
225 |
durations = data_dict["duration"]
|
226 |
+
train_dataset = CustomDataset(
|
227 |
+
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
|
228 |
+
)
|
229 |
+
|
230 |
elif dataset_type == "CustomDatasetPath":
|
231 |
try:
|
232 |
train_dataset = load_from_disk(f"{dataset_name}/raw")
|
233 |
+
except: # noqa: E722
|
234 |
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
|
235 |
+
|
236 |
+
with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
|
237 |
data_dict = json.load(f)
|
238 |
durations = data_dict["duration"]
|
239 |
+
train_dataset = CustomDataset(
|
240 |
+
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
|
241 |
+
)
|
242 |
+
|
243 |
elif dataset_type == "HFDataset":
|
244 |
+
print(
|
245 |
+
"Should manually modify the path of huggingface dataset to your need.\n"
|
246 |
+
+ "May also the corresponding script cuz different dataset may have different format."
|
247 |
+
)
|
248 |
pre, post = dataset_name.split("_")
|
249 |
+
train_dataset = HFDataset(
|
250 |
+
load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
|
251 |
+
)
|
252 |
|
253 |
return train_dataset
|
254 |
|
255 |
|
256 |
# collation
|
257 |
|
258 |
+
|
259 |
def collate_fn(batch):
|
260 |
+
mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
|
261 |
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
|
262 |
max_mel_length = mel_lengths.amax()
|
263 |
|
264 |
padded_mel_specs = []
|
265 |
for spec in mel_specs: # TODO. maybe records mask for attention here
|
266 |
padding = (0, max_mel_length - spec.size(-1))
|
267 |
+
padded_spec = F.pad(spec, padding, value=0)
|
268 |
padded_mel_specs.append(padded_spec)
|
269 |
+
|
270 |
mel_specs = torch.stack(padded_mel_specs)
|
271 |
|
272 |
+
text = [item["text"] for item in batch]
|
273 |
text_lengths = torch.LongTensor([len(item) for item in text])
|
274 |
|
275 |
return dict(
|
276 |
+
mel=mel_specs,
|
277 |
+
mel_lengths=mel_lengths,
|
278 |
+
text=text,
|
279 |
+
text_lengths=text_lengths,
|
280 |
)
|
model/ecapa_tdnn.py
CHANGED
@@ -9,13 +9,14 @@ import torch.nn as nn
|
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
|
12 |
-
|
13 |
-
|
|
|
14 |
|
15 |
class Res2Conv1dReluBn(nn.Module):
|
16 |
-
|
17 |
in_channels == out_channels == channels
|
18 |
-
|
19 |
|
20 |
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
|
21 |
super().__init__()
|
@@ -51,8 +52,9 @@ class Res2Conv1dReluBn(nn.Module):
|
|
51 |
return out
|
52 |
|
53 |
|
54 |
-
|
55 |
-
|
|
|
56 |
|
57 |
class Conv1dReluBn(nn.Module):
|
58 |
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
|
@@ -64,8 +66,9 @@ class Conv1dReluBn(nn.Module):
|
|
64 |
return self.bn(F.relu(self.conv(x)))
|
65 |
|
66 |
|
67 |
-
|
68 |
-
|
|
|
69 |
|
70 |
class SE_Connect(nn.Module):
|
71 |
def __init__(self, channels, se_bottleneck_dim=128):
|
@@ -82,8 +85,8 @@ class SE_Connect(nn.Module):
|
|
82 |
return out
|
83 |
|
84 |
|
85 |
-
|
86 |
-
|
87 |
|
88 |
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
89 |
# return nn.Sequential(
|
@@ -93,6 +96,7 @@ class SE_Connect(nn.Module):
|
|
93 |
# SE_Connect(channels)
|
94 |
# )
|
95 |
|
|
|
96 |
class SE_Res2Block(nn.Module):
|
97 |
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
|
98 |
super().__init__()
|
@@ -122,8 +126,9 @@ class SE_Res2Block(nn.Module):
|
|
122 |
return x + residual
|
123 |
|
124 |
|
125 |
-
|
126 |
-
|
|
|
127 |
|
128 |
class AttentiveStatsPool(nn.Module):
|
129 |
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
@@ -138,7 +143,6 @@ class AttentiveStatsPool(nn.Module):
|
|
138 |
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
|
139 |
|
140 |
def forward(self, x):
|
141 |
-
|
142 |
if self.global_context_att:
|
143 |
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
144 |
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
@@ -151,38 +155,52 @@ class AttentiveStatsPool(nn.Module):
|
|
151 |
# alpha = F.relu(self.linear1(x_in))
|
152 |
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
153 |
mean = torch.sum(alpha * x, dim=2)
|
154 |
-
residuals = torch.sum(alpha * (x
|
155 |
std = torch.sqrt(residuals.clamp(min=1e-9))
|
156 |
return torch.cat([mean, std], dim=1)
|
157 |
|
158 |
|
159 |
class ECAPA_TDNN(nn.Module):
|
160 |
-
def __init__(
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
super().__init__()
|
163 |
|
164 |
self.feat_type = feat_type
|
165 |
self.feature_selection = feature_selection
|
166 |
self.update_extract = update_extract
|
167 |
self.sr = sr
|
168 |
-
|
169 |
-
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
|
170 |
try:
|
171 |
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
|
172 |
-
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source=
|
173 |
-
except:
|
174 |
-
self.feature_extract = torch.hub.load(
|
175 |
|
176 |
-
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
|
|
|
|
177 |
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
|
178 |
-
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
|
|
|
|
179 |
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
|
180 |
|
181 |
self.feat_num = self.get_feat_num()
|
182 |
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
183 |
|
184 |
-
if feat_type !=
|
185 |
-
freeze_list = [
|
186 |
for name, param in self.feature_extract.named_parameters():
|
187 |
for freeze_val in freeze_list:
|
188 |
if freeze_val in name:
|
@@ -198,18 +216,46 @@ class ECAPA_TDNN(nn.Module):
|
|
198 |
self.channels = [channels] * 4 + [1536]
|
199 |
|
200 |
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
201 |
-
self.layer2 = SE_Res2Block(
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
206 |
cat_channels = channels * 3
|
207 |
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
208 |
-
self.pooling = AttentiveStatsPool(
|
|
|
|
|
209 |
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
210 |
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
211 |
|
212 |
-
|
213 |
def get_feat_num(self):
|
214 |
self.feature_extract.eval()
|
215 |
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
@@ -226,12 +272,12 @@ class ECAPA_TDNN(nn.Module):
|
|
226 |
x = self.feature_extract([sample for sample in x])
|
227 |
else:
|
228 |
with torch.no_grad():
|
229 |
-
if self.feat_type ==
|
230 |
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
|
231 |
else:
|
232 |
x = self.feature_extract([sample for sample in x])
|
233 |
|
234 |
-
if self.feat_type ==
|
235 |
x = x.log()
|
236 |
|
237 |
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
@@ -263,6 +309,22 @@ class ECAPA_TDNN(nn.Module):
|
|
263 |
return out
|
264 |
|
265 |
|
266 |
-
def ECAPA_TDNN_SMALL(
|
267 |
-
|
268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
|
12 |
+
""" Res2Conv1d + BatchNorm1d + ReLU
|
13 |
+
"""
|
14 |
+
|
15 |
|
16 |
class Res2Conv1dReluBn(nn.Module):
|
17 |
+
"""
|
18 |
in_channels == out_channels == channels
|
19 |
+
"""
|
20 |
|
21 |
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
|
22 |
super().__init__()
|
|
|
52 |
return out
|
53 |
|
54 |
|
55 |
+
""" Conv1d + BatchNorm1d + ReLU
|
56 |
+
"""
|
57 |
+
|
58 |
|
59 |
class Conv1dReluBn(nn.Module):
|
60 |
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
|
|
|
66 |
return self.bn(F.relu(self.conv(x)))
|
67 |
|
68 |
|
69 |
+
""" The SE connection of 1D case.
|
70 |
+
"""
|
71 |
+
|
72 |
|
73 |
class SE_Connect(nn.Module):
|
74 |
def __init__(self, channels, se_bottleneck_dim=128):
|
|
|
85 |
return out
|
86 |
|
87 |
|
88 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
89 |
+
"""
|
90 |
|
91 |
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
92 |
# return nn.Sequential(
|
|
|
96 |
# SE_Connect(channels)
|
97 |
# )
|
98 |
|
99 |
+
|
100 |
class SE_Res2Block(nn.Module):
|
101 |
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
|
102 |
super().__init__()
|
|
|
126 |
return x + residual
|
127 |
|
128 |
|
129 |
+
""" Attentive weighted mean and standard deviation pooling.
|
130 |
+
"""
|
131 |
+
|
132 |
|
133 |
class AttentiveStatsPool(nn.Module):
|
134 |
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
|
|
143 |
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
|
144 |
|
145 |
def forward(self, x):
|
|
|
146 |
if self.global_context_att:
|
147 |
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
148 |
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
|
|
155 |
# alpha = F.relu(self.linear1(x_in))
|
156 |
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
157 |
mean = torch.sum(alpha * x, dim=2)
|
158 |
+
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
|
159 |
std = torch.sqrt(residuals.clamp(min=1e-9))
|
160 |
return torch.cat([mean, std], dim=1)
|
161 |
|
162 |
|
163 |
class ECAPA_TDNN(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
feat_dim=80,
|
167 |
+
channels=512,
|
168 |
+
emb_dim=192,
|
169 |
+
global_context_att=False,
|
170 |
+
feat_type="wavlm_large",
|
171 |
+
sr=16000,
|
172 |
+
feature_selection="hidden_states",
|
173 |
+
update_extract=False,
|
174 |
+
config_path=None,
|
175 |
+
):
|
176 |
super().__init__()
|
177 |
|
178 |
self.feat_type = feat_type
|
179 |
self.feature_selection = feature_selection
|
180 |
self.update_extract = update_extract
|
181 |
self.sr = sr
|
182 |
+
|
183 |
+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
184 |
try:
|
185 |
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
|
186 |
+
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
|
187 |
+
except: # noqa: E722
|
188 |
+
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
|
189 |
|
190 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
191 |
+
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
|
192 |
+
):
|
193 |
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
|
194 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
195 |
+
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
|
196 |
+
):
|
197 |
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
|
198 |
|
199 |
self.feat_num = self.get_feat_num()
|
200 |
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
201 |
|
202 |
+
if feat_type != "fbank" and feat_type != "mfcc":
|
203 |
+
freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
|
204 |
for name, param in self.feature_extract.named_parameters():
|
205 |
for freeze_val in freeze_list:
|
206 |
if freeze_val in name:
|
|
|
216 |
self.channels = [channels] * 4 + [1536]
|
217 |
|
218 |
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
219 |
+
self.layer2 = SE_Res2Block(
|
220 |
+
self.channels[0],
|
221 |
+
self.channels[1],
|
222 |
+
kernel_size=3,
|
223 |
+
stride=1,
|
224 |
+
padding=2,
|
225 |
+
dilation=2,
|
226 |
+
scale=8,
|
227 |
+
se_bottleneck_dim=128,
|
228 |
+
)
|
229 |
+
self.layer3 = SE_Res2Block(
|
230 |
+
self.channels[1],
|
231 |
+
self.channels[2],
|
232 |
+
kernel_size=3,
|
233 |
+
stride=1,
|
234 |
+
padding=3,
|
235 |
+
dilation=3,
|
236 |
+
scale=8,
|
237 |
+
se_bottleneck_dim=128,
|
238 |
+
)
|
239 |
+
self.layer4 = SE_Res2Block(
|
240 |
+
self.channels[2],
|
241 |
+
self.channels[3],
|
242 |
+
kernel_size=3,
|
243 |
+
stride=1,
|
244 |
+
padding=4,
|
245 |
+
dilation=4,
|
246 |
+
scale=8,
|
247 |
+
se_bottleneck_dim=128,
|
248 |
+
)
|
249 |
|
250 |
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
251 |
cat_channels = channels * 3
|
252 |
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
253 |
+
self.pooling = AttentiveStatsPool(
|
254 |
+
self.channels[-1], attention_channels=128, global_context_att=global_context_att
|
255 |
+
)
|
256 |
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
257 |
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
258 |
|
|
|
259 |
def get_feat_num(self):
|
260 |
self.feature_extract.eval()
|
261 |
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
|
|
272 |
x = self.feature_extract([sample for sample in x])
|
273 |
else:
|
274 |
with torch.no_grad():
|
275 |
+
if self.feat_type == "fbank" or self.feat_type == "mfcc":
|
276 |
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
|
277 |
else:
|
278 |
x = self.feature_extract([sample for sample in x])
|
279 |
|
280 |
+
if self.feat_type == "fbank":
|
281 |
x = x.log()
|
282 |
|
283 |
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
|
|
309 |
return out
|
310 |
|
311 |
|
312 |
+
def ECAPA_TDNN_SMALL(
|
313 |
+
feat_dim,
|
314 |
+
emb_dim=256,
|
315 |
+
feat_type="wavlm_large",
|
316 |
+
sr=16000,
|
317 |
+
feature_selection="hidden_states",
|
318 |
+
update_extract=False,
|
319 |
+
config_path=None,
|
320 |
+
):
|
321 |
+
return ECAPA_TDNN(
|
322 |
+
feat_dim=feat_dim,
|
323 |
+
channels=512,
|
324 |
+
emb_dim=emb_dim,
|
325 |
+
feat_type=feat_type,
|
326 |
+
sr=sr,
|
327 |
+
feature_selection=feature_selection,
|
328 |
+
update_extract=update_extract,
|
329 |
+
config_path=config_path,
|
330 |
+
)
|
model/modules.py
CHANGED
@@ -21,39 +21,40 @@ from x_transformers.x_transformers import apply_rotary_pos_emb
|
|
21 |
|
22 |
# raw wav to mel spec
|
23 |
|
|
|
24 |
class MelSpec(nn.Module):
|
25 |
def __init__(
|
26 |
self,
|
27 |
-
filter_length
|
28 |
-
hop_length
|
29 |
-
win_length
|
30 |
-
n_mel_channels
|
31 |
-
target_sample_rate
|
32 |
-
normalize
|
33 |
-
power
|
34 |
-
norm
|
35 |
-
center
|
36 |
):
|
37 |
super().__init__()
|
38 |
self.n_mel_channels = n_mel_channels
|
39 |
|
40 |
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
41 |
-
sample_rate
|
42 |
-
n_fft
|
43 |
-
win_length
|
44 |
-
hop_length
|
45 |
-
n_mels
|
46 |
-
power
|
47 |
-
center
|
48 |
-
normalized
|
49 |
-
norm
|
50 |
)
|
51 |
|
52 |
-
self.register_buffer(
|
53 |
|
54 |
def forward(self, inp):
|
55 |
if len(inp.shape) == 3:
|
56 |
-
inp = inp.squeeze(1)
|
57 |
|
58 |
assert len(inp.shape) == 2
|
59 |
|
@@ -61,12 +62,13 @@ class MelSpec(nn.Module):
|
|
61 |
self.to(inp.device)
|
62 |
|
63 |
mel = self.mel_stft(inp)
|
64 |
-
mel = mel.clamp(min
|
65 |
return mel
|
66 |
-
|
67 |
|
68 |
# sinusoidal position embedding
|
69 |
|
|
|
70 |
class SinusPositionEmbedding(nn.Module):
|
71 |
def __init__(self, dim):
|
72 |
super().__init__()
|
@@ -84,35 +86,37 @@ class SinusPositionEmbedding(nn.Module):
|
|
84 |
|
85 |
# convolutional position embedding
|
86 |
|
|
|
87 |
class ConvPositionEmbedding(nn.Module):
|
88 |
-
def __init__(self, dim, kernel_size
|
89 |
super().__init__()
|
90 |
assert kernel_size % 2 != 0
|
91 |
self.conv1d = nn.Sequential(
|
92 |
-
nn.Conv1d(dim, dim, kernel_size, groups
|
93 |
nn.Mish(),
|
94 |
-
nn.Conv1d(dim, dim, kernel_size, groups
|
95 |
nn.Mish(),
|
96 |
)
|
97 |
|
98 |
-
def forward(self, x: float[
|
99 |
if mask is not None:
|
100 |
mask = mask[..., None]
|
101 |
-
x = x.masked_fill(~mask, 0.)
|
102 |
|
103 |
x = x.permute(0, 2, 1)
|
104 |
x = self.conv1d(x)
|
105 |
out = x.permute(0, 2, 1)
|
106 |
|
107 |
if mask is not None:
|
108 |
-
out = out.masked_fill(~mask, 0.)
|
109 |
|
110 |
return out
|
111 |
|
112 |
|
113 |
# rotary positional embedding related
|
114 |
|
115 |
-
|
|
|
116 |
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
117 |
# has some connection to NTK literature
|
118 |
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
@@ -125,12 +129,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
|
|
125 |
freqs_sin = torch.sin(freqs) # imaginary part
|
126 |
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
127 |
|
128 |
-
|
|
|
129 |
# length = length if isinstance(length, int) else length.max()
|
130 |
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
131 |
-
pos =
|
132 |
-
|
133 |
-
|
|
|
134 |
# avoid extra long error.
|
135 |
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
136 |
return pos
|
@@ -138,6 +144,7 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.):
|
|
138 |
|
139 |
# Global Response Normalization layer (Instance Normalization ?)
|
140 |
|
|
|
141 |
class GRN(nn.Module):
|
142 |
def __init__(self, dim):
|
143 |
super().__init__()
|
@@ -153,6 +160,7 @@ class GRN(nn.Module):
|
|
153 |
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
154 |
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
155 |
|
|
|
156 |
class ConvNeXtV2Block(nn.Module):
|
157 |
def __init__(
|
158 |
self,
|
@@ -162,7 +170,9 @@ class ConvNeXtV2Block(nn.Module):
|
|
162 |
):
|
163 |
super().__init__()
|
164 |
padding = (dilation * (7 - 1)) // 2
|
165 |
-
self.dwconv = nn.Conv1d(
|
|
|
|
|
166 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
167 |
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
168 |
self.act = nn.GELU()
|
@@ -185,6 +195,7 @@ class ConvNeXtV2Block(nn.Module):
|
|
185 |
# AdaLayerNormZero
|
186 |
# return with modulated x for attn input, and params for later mlp modulation
|
187 |
|
|
|
188 |
class AdaLayerNormZero(nn.Module):
|
189 |
def __init__(self, dim):
|
190 |
super().__init__()
|
@@ -194,7 +205,7 @@ class AdaLayerNormZero(nn.Module):
|
|
194 |
|
195 |
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
196 |
|
197 |
-
def forward(self, x, emb
|
198 |
emb = self.linear(self.silu(emb))
|
199 |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
200 |
|
@@ -205,6 +216,7 @@ class AdaLayerNormZero(nn.Module):
|
|
205 |
# AdaLayerNormZero for final layer
|
206 |
# return only with modulated x for attn input, cuz no more mlp modulation
|
207 |
|
|
|
208 |
class AdaLayerNormZero_Final(nn.Module):
|
209 |
def __init__(self, dim):
|
210 |
super().__init__()
|
@@ -224,22 +236,16 @@ class AdaLayerNormZero_Final(nn.Module):
|
|
224 |
|
225 |
# FeedForward
|
226 |
|
|
|
227 |
class FeedForward(nn.Module):
|
228 |
-
def __init__(self, dim, dim_out
|
229 |
super().__init__()
|
230 |
inner_dim = int(dim * mult)
|
231 |
dim_out = dim_out if dim_out is not None else dim
|
232 |
|
233 |
activation = nn.GELU(approximate=approximate)
|
234 |
-
project_in = nn.Sequential(
|
235 |
-
|
236 |
-
activation
|
237 |
-
)
|
238 |
-
self.ff = nn.Sequential(
|
239 |
-
project_in,
|
240 |
-
nn.Dropout(dropout),
|
241 |
-
nn.Linear(inner_dim, dim_out)
|
242 |
-
)
|
243 |
|
244 |
def forward(self, x):
|
245 |
return self.ff(x)
|
@@ -248,6 +254,7 @@ class FeedForward(nn.Module):
|
|
248 |
# Attention with possible joint part
|
249 |
# modified from diffusers/src/diffusers/models/attention_processor.py
|
250 |
|
|
|
251 |
class Attention(nn.Module):
|
252 |
def __init__(
|
253 |
self,
|
@@ -256,8 +263,8 @@ class Attention(nn.Module):
|
|
256 |
heads: int = 8,
|
257 |
dim_head: int = 64,
|
258 |
dropout: float = 0.0,
|
259 |
-
context_dim: Optional[int] = None,
|
260 |
-
context_pre_only
|
261 |
):
|
262 |
super().__init__()
|
263 |
|
@@ -293,20 +300,21 @@ class Attention(nn.Module):
|
|
293 |
|
294 |
def forward(
|
295 |
self,
|
296 |
-
x: float[
|
297 |
-
c: float[
|
298 |
-
mask: bool[
|
299 |
-
rope
|
300 |
-
c_rope
|
301 |
) -> torch.Tensor:
|
302 |
if c is not None:
|
303 |
-
return self.processor(self, x, c
|
304 |
else:
|
305 |
-
return self.processor(self, x, mask
|
306 |
|
307 |
|
308 |
# Attention processor
|
309 |
|
|
|
310 |
class AttnProcessor:
|
311 |
def __init__(self):
|
312 |
pass
|
@@ -314,11 +322,10 @@ class AttnProcessor:
|
|
314 |
def __call__(
|
315 |
self,
|
316 |
attn: Attention,
|
317 |
-
x: float[
|
318 |
-
mask: bool[
|
319 |
-
rope
|
320 |
) -> torch.FloatTensor:
|
321 |
-
|
322 |
batch_size = x.shape[0]
|
323 |
|
324 |
# `sample` projections.
|
@@ -329,7 +336,7 @@ class AttnProcessor:
|
|
329 |
# apply rotary position embedding
|
330 |
if rope is not None:
|
331 |
freqs, xpos_scale = rope
|
332 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale
|
333 |
|
334 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
335 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
@@ -360,14 +367,15 @@ class AttnProcessor:
|
|
360 |
|
361 |
if mask is not None:
|
362 |
mask = mask.unsqueeze(-1)
|
363 |
-
x = x.masked_fill(~mask, 0.)
|
364 |
|
365 |
return x
|
366 |
-
|
367 |
|
368 |
# Joint Attention processor for MM-DiT
|
369 |
# modified from diffusers/src/diffusers/models/attention_processor.py
|
370 |
|
|
|
371 |
class JointAttnProcessor:
|
372 |
def __init__(self):
|
373 |
pass
|
@@ -375,11 +383,11 @@ class JointAttnProcessor:
|
|
375 |
def __call__(
|
376 |
self,
|
377 |
attn: Attention,
|
378 |
-
x: float[
|
379 |
-
c: float[
|
380 |
-
mask: bool[
|
381 |
-
rope
|
382 |
-
c_rope
|
383 |
) -> torch.FloatTensor:
|
384 |
residual = x
|
385 |
|
@@ -398,12 +406,12 @@ class JointAttnProcessor:
|
|
398 |
# apply rope for context and noised input independently
|
399 |
if rope is not None:
|
400 |
freqs, xpos_scale = rope
|
401 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale
|
402 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
403 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
404 |
if c_rope is not None:
|
405 |
freqs, xpos_scale = c_rope
|
406 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale
|
407 |
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
408 |
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
409 |
|
@@ -420,7 +428,7 @@ class JointAttnProcessor:
|
|
420 |
|
421 |
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
422 |
if mask is not None:
|
423 |
-
attn_mask = F.pad(mask, (0, c.shape[1]), value
|
424 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
425 |
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
426 |
else:
|
@@ -432,8 +440,8 @@ class JointAttnProcessor:
|
|
432 |
|
433 |
# Split the attention outputs.
|
434 |
x, c = (
|
435 |
-
x[:, :residual.shape[1]],
|
436 |
-
x[:, residual.shape[1]:],
|
437 |
)
|
438 |
|
439 |
# linear proj
|
@@ -445,7 +453,7 @@ class JointAttnProcessor:
|
|
445 |
|
446 |
if mask is not None:
|
447 |
mask = mask.unsqueeze(-1)
|
448 |
-
x = x.masked_fill(~mask, 0.)
|
449 |
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
450 |
|
451 |
return x, c
|
@@ -453,24 +461,24 @@ class JointAttnProcessor:
|
|
453 |
|
454 |
# DiT Block
|
455 |
|
456 |
-
class DiTBlock(nn.Module):
|
457 |
|
458 |
-
|
|
|
459 |
super().__init__()
|
460 |
-
|
461 |
self.attn_norm = AdaLayerNormZero(dim)
|
462 |
self.attn = Attention(
|
463 |
-
processor
|
464 |
-
dim
|
465 |
-
heads
|
466 |
-
dim_head
|
467 |
-
dropout
|
468 |
-
|
469 |
-
|
470 |
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
471 |
-
self.ff = FeedForward(dim
|
472 |
|
473 |
-
def forward(self, x, t, mask
|
474 |
# pre-norm & modulation for attention input
|
475 |
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
476 |
|
@@ -479,7 +487,7 @@ class DiTBlock(nn.Module):
|
|
479 |
|
480 |
# process attention output for input x
|
481 |
x = x + gate_msa.unsqueeze(1) * attn_output
|
482 |
-
|
483 |
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
484 |
ff_output = self.ff(norm)
|
485 |
x = x + gate_mlp.unsqueeze(1) * ff_output
|
@@ -489,8 +497,9 @@ class DiTBlock(nn.Module):
|
|
489 |
|
490 |
# MMDiT Block https://arxiv.org/abs/2403.03206
|
491 |
|
|
|
492 |
class MMDiTBlock(nn.Module):
|
493 |
-
r"""
|
494 |
modified from diffusers/src/diffusers/models/attention.py
|
495 |
|
496 |
notes.
|
@@ -499,33 +508,33 @@ class MMDiTBlock(nn.Module):
|
|
499 |
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
500 |
"""
|
501 |
|
502 |
-
def __init__(self, dim, heads, dim_head, ff_mult
|
503 |
super().__init__()
|
504 |
|
505 |
self.context_pre_only = context_pre_only
|
506 |
-
|
507 |
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
508 |
self.attn_norm_x = AdaLayerNormZero(dim)
|
509 |
self.attn = Attention(
|
510 |
-
processor
|
511 |
-
dim
|
512 |
-
heads
|
513 |
-
dim_head
|
514 |
-
dropout
|
515 |
-
context_dim
|
516 |
-
context_pre_only
|
517 |
-
|
518 |
|
519 |
if not context_pre_only:
|
520 |
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
521 |
-
self.ff_c = FeedForward(dim
|
522 |
else:
|
523 |
self.ff_norm_c = None
|
524 |
self.ff_c = None
|
525 |
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
526 |
-
self.ff_x = FeedForward(dim
|
527 |
|
528 |
-
def forward(self, x, c, t, mask
|
529 |
# pre-norm & modulation for attention input
|
530 |
if self.context_pre_only:
|
531 |
norm_c = self.attn_norm_c(c, t)
|
@@ -539,7 +548,7 @@ class MMDiTBlock(nn.Module):
|
|
539 |
# process attention output for context c
|
540 |
if self.context_pre_only:
|
541 |
c = None
|
542 |
-
else:
|
543 |
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
544 |
|
545 |
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
@@ -548,7 +557,7 @@ class MMDiTBlock(nn.Module):
|
|
548 |
|
549 |
# process attention output for input x
|
550 |
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
551 |
-
|
552 |
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
553 |
x_ff_output = self.ff_x(norm_x)
|
554 |
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
@@ -558,17 +567,14 @@ class MMDiTBlock(nn.Module):
|
|
558 |
|
559 |
# time step conditioning embedding
|
560 |
|
|
|
561 |
class TimestepEmbedding(nn.Module):
|
562 |
def __init__(self, dim, freq_embed_dim=256):
|
563 |
super().__init__()
|
564 |
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
565 |
-
self.time_mlp = nn.Sequential(
|
566 |
-
nn.Linear(freq_embed_dim, dim),
|
567 |
-
nn.SiLU(),
|
568 |
-
nn.Linear(dim, dim)
|
569 |
-
)
|
570 |
|
571 |
-
def forward(self, timestep: float[
|
572 |
time_hidden = self.time_embed(timestep)
|
573 |
time_hidden = time_hidden.to(timestep.dtype)
|
574 |
time = self.time_mlp(time_hidden) # b d
|
|
|
21 |
|
22 |
# raw wav to mel spec
|
23 |
|
24 |
+
|
25 |
class MelSpec(nn.Module):
|
26 |
def __init__(
|
27 |
self,
|
28 |
+
filter_length=1024,
|
29 |
+
hop_length=256,
|
30 |
+
win_length=1024,
|
31 |
+
n_mel_channels=100,
|
32 |
+
target_sample_rate=24_000,
|
33 |
+
normalize=False,
|
34 |
+
power=1,
|
35 |
+
norm=None,
|
36 |
+
center=True,
|
37 |
):
|
38 |
super().__init__()
|
39 |
self.n_mel_channels = n_mel_channels
|
40 |
|
41 |
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
42 |
+
sample_rate=target_sample_rate,
|
43 |
+
n_fft=filter_length,
|
44 |
+
win_length=win_length,
|
45 |
+
hop_length=hop_length,
|
46 |
+
n_mels=n_mel_channels,
|
47 |
+
power=power,
|
48 |
+
center=center,
|
49 |
+
normalized=normalize,
|
50 |
+
norm=norm,
|
51 |
)
|
52 |
|
53 |
+
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
54 |
|
55 |
def forward(self, inp):
|
56 |
if len(inp.shape) == 3:
|
57 |
+
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
|
58 |
|
59 |
assert len(inp.shape) == 2
|
60 |
|
|
|
62 |
self.to(inp.device)
|
63 |
|
64 |
mel = self.mel_stft(inp)
|
65 |
+
mel = mel.clamp(min=1e-5).log()
|
66 |
return mel
|
67 |
+
|
68 |
|
69 |
# sinusoidal position embedding
|
70 |
|
71 |
+
|
72 |
class SinusPositionEmbedding(nn.Module):
|
73 |
def __init__(self, dim):
|
74 |
super().__init__()
|
|
|
86 |
|
87 |
# convolutional position embedding
|
88 |
|
89 |
+
|
90 |
class ConvPositionEmbedding(nn.Module):
|
91 |
+
def __init__(self, dim, kernel_size=31, groups=16):
|
92 |
super().__init__()
|
93 |
assert kernel_size % 2 != 0
|
94 |
self.conv1d = nn.Sequential(
|
95 |
+
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
96 |
nn.Mish(),
|
97 |
+
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
98 |
nn.Mish(),
|
99 |
)
|
100 |
|
101 |
+
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
102 |
if mask is not None:
|
103 |
mask = mask[..., None]
|
104 |
+
x = x.masked_fill(~mask, 0.0)
|
105 |
|
106 |
x = x.permute(0, 2, 1)
|
107 |
x = self.conv1d(x)
|
108 |
out = x.permute(0, 2, 1)
|
109 |
|
110 |
if mask is not None:
|
111 |
+
out = out.masked_fill(~mask, 0.0)
|
112 |
|
113 |
return out
|
114 |
|
115 |
|
116 |
# rotary positional embedding related
|
117 |
|
118 |
+
|
119 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
120 |
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
121 |
# has some connection to NTK literature
|
122 |
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
|
|
129 |
freqs_sin = torch.sin(freqs) # imaginary part
|
130 |
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
131 |
|
132 |
+
|
133 |
+
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
134 |
# length = length if isinstance(length, int) else length.max()
|
135 |
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
136 |
+
pos = (
|
137 |
+
start.unsqueeze(1)
|
138 |
+
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
139 |
+
)
|
140 |
# avoid extra long error.
|
141 |
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
142 |
return pos
|
|
|
144 |
|
145 |
# Global Response Normalization layer (Instance Normalization ?)
|
146 |
|
147 |
+
|
148 |
class GRN(nn.Module):
|
149 |
def __init__(self, dim):
|
150 |
super().__init__()
|
|
|
160 |
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
161 |
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
162 |
|
163 |
+
|
164 |
class ConvNeXtV2Block(nn.Module):
|
165 |
def __init__(
|
166 |
self,
|
|
|
170 |
):
|
171 |
super().__init__()
|
172 |
padding = (dilation * (7 - 1)) // 2
|
173 |
+
self.dwconv = nn.Conv1d(
|
174 |
+
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
175 |
+
) # depthwise conv
|
176 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
177 |
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
178 |
self.act = nn.GELU()
|
|
|
195 |
# AdaLayerNormZero
|
196 |
# return with modulated x for attn input, and params for later mlp modulation
|
197 |
|
198 |
+
|
199 |
class AdaLayerNormZero(nn.Module):
|
200 |
def __init__(self, dim):
|
201 |
super().__init__()
|
|
|
205 |
|
206 |
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
207 |
|
208 |
+
def forward(self, x, emb=None):
|
209 |
emb = self.linear(self.silu(emb))
|
210 |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
211 |
|
|
|
216 |
# AdaLayerNormZero for final layer
|
217 |
# return only with modulated x for attn input, cuz no more mlp modulation
|
218 |
|
219 |
+
|
220 |
class AdaLayerNormZero_Final(nn.Module):
|
221 |
def __init__(self, dim):
|
222 |
super().__init__()
|
|
|
236 |
|
237 |
# FeedForward
|
238 |
|
239 |
+
|
240 |
class FeedForward(nn.Module):
|
241 |
+
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
|
242 |
super().__init__()
|
243 |
inner_dim = int(dim * mult)
|
244 |
dim_out = dim_out if dim_out is not None else dim
|
245 |
|
246 |
activation = nn.GELU(approximate=approximate)
|
247 |
+
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
248 |
+
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
def forward(self, x):
|
251 |
return self.ff(x)
|
|
|
254 |
# Attention with possible joint part
|
255 |
# modified from diffusers/src/diffusers/models/attention_processor.py
|
256 |
|
257 |
+
|
258 |
class Attention(nn.Module):
|
259 |
def __init__(
|
260 |
self,
|
|
|
263 |
heads: int = 8,
|
264 |
dim_head: int = 64,
|
265 |
dropout: float = 0.0,
|
266 |
+
context_dim: Optional[int] = None, # if not None -> joint attention
|
267 |
+
context_pre_only=None,
|
268 |
):
|
269 |
super().__init__()
|
270 |
|
|
|
300 |
|
301 |
def forward(
|
302 |
self,
|
303 |
+
x: float["b n d"], # noised input x # noqa: F722
|
304 |
+
c: float["b n d"] = None, # context c # noqa: F722
|
305 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
306 |
+
rope=None, # rotary position embedding for x
|
307 |
+
c_rope=None, # rotary position embedding for c
|
308 |
) -> torch.Tensor:
|
309 |
if c is not None:
|
310 |
+
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
311 |
else:
|
312 |
+
return self.processor(self, x, mask=mask, rope=rope)
|
313 |
|
314 |
|
315 |
# Attention processor
|
316 |
|
317 |
+
|
318 |
class AttnProcessor:
|
319 |
def __init__(self):
|
320 |
pass
|
|
|
322 |
def __call__(
|
323 |
self,
|
324 |
attn: Attention,
|
325 |
+
x: float["b n d"], # noised input x # noqa: F722
|
326 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
327 |
+
rope=None, # rotary position embedding
|
328 |
) -> torch.FloatTensor:
|
|
|
329 |
batch_size = x.shape[0]
|
330 |
|
331 |
# `sample` projections.
|
|
|
336 |
# apply rotary position embedding
|
337 |
if rope is not None:
|
338 |
freqs, xpos_scale = rope
|
339 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
340 |
|
341 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
342 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
|
|
367 |
|
368 |
if mask is not None:
|
369 |
mask = mask.unsqueeze(-1)
|
370 |
+
x = x.masked_fill(~mask, 0.0)
|
371 |
|
372 |
return x
|
373 |
+
|
374 |
|
375 |
# Joint Attention processor for MM-DiT
|
376 |
# modified from diffusers/src/diffusers/models/attention_processor.py
|
377 |
|
378 |
+
|
379 |
class JointAttnProcessor:
|
380 |
def __init__(self):
|
381 |
pass
|
|
|
383 |
def __call__(
|
384 |
self,
|
385 |
attn: Attention,
|
386 |
+
x: float["b n d"], # noised input x # noqa: F722
|
387 |
+
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
388 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
389 |
+
rope=None, # rotary position embedding for x
|
390 |
+
c_rope=None, # rotary position embedding for c
|
391 |
) -> torch.FloatTensor:
|
392 |
residual = x
|
393 |
|
|
|
406 |
# apply rope for context and noised input independently
|
407 |
if rope is not None:
|
408 |
freqs, xpos_scale = rope
|
409 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
410 |
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
411 |
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
412 |
if c_rope is not None:
|
413 |
freqs, xpos_scale = c_rope
|
414 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
415 |
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
416 |
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
417 |
|
|
|
428 |
|
429 |
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
430 |
if mask is not None:
|
431 |
+
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
432 |
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
433 |
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
434 |
else:
|
|
|
440 |
|
441 |
# Split the attention outputs.
|
442 |
x, c = (
|
443 |
+
x[:, : residual.shape[1]],
|
444 |
+
x[:, residual.shape[1] :],
|
445 |
)
|
446 |
|
447 |
# linear proj
|
|
|
453 |
|
454 |
if mask is not None:
|
455 |
mask = mask.unsqueeze(-1)
|
456 |
+
x = x.masked_fill(~mask, 0.0)
|
457 |
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
458 |
|
459 |
return x, c
|
|
|
461 |
|
462 |
# DiT Block
|
463 |
|
|
|
464 |
|
465 |
+
class DiTBlock(nn.Module):
|
466 |
+
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
467 |
super().__init__()
|
468 |
+
|
469 |
self.attn_norm = AdaLayerNormZero(dim)
|
470 |
self.attn = Attention(
|
471 |
+
processor=AttnProcessor(),
|
472 |
+
dim=dim,
|
473 |
+
heads=heads,
|
474 |
+
dim_head=dim_head,
|
475 |
+
dropout=dropout,
|
476 |
+
)
|
477 |
+
|
478 |
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
479 |
+
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
480 |
|
481 |
+
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
482 |
# pre-norm & modulation for attention input
|
483 |
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
484 |
|
|
|
487 |
|
488 |
# process attention output for input x
|
489 |
x = x + gate_msa.unsqueeze(1) * attn_output
|
490 |
+
|
491 |
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
492 |
ff_output = self.ff(norm)
|
493 |
x = x + gate_mlp.unsqueeze(1) * ff_output
|
|
|
497 |
|
498 |
# MMDiT Block https://arxiv.org/abs/2403.03206
|
499 |
|
500 |
+
|
501 |
class MMDiTBlock(nn.Module):
|
502 |
+
r"""
|
503 |
modified from diffusers/src/diffusers/models/attention.py
|
504 |
|
505 |
notes.
|
|
|
508 |
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
509 |
"""
|
510 |
|
511 |
+
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
|
512 |
super().__init__()
|
513 |
|
514 |
self.context_pre_only = context_pre_only
|
515 |
+
|
516 |
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
517 |
self.attn_norm_x = AdaLayerNormZero(dim)
|
518 |
self.attn = Attention(
|
519 |
+
processor=JointAttnProcessor(),
|
520 |
+
dim=dim,
|
521 |
+
heads=heads,
|
522 |
+
dim_head=dim_head,
|
523 |
+
dropout=dropout,
|
524 |
+
context_dim=dim,
|
525 |
+
context_pre_only=context_pre_only,
|
526 |
+
)
|
527 |
|
528 |
if not context_pre_only:
|
529 |
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
530 |
+
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
531 |
else:
|
532 |
self.ff_norm_c = None
|
533 |
self.ff_c = None
|
534 |
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
535 |
+
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
536 |
|
537 |
+
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
|
538 |
# pre-norm & modulation for attention input
|
539 |
if self.context_pre_only:
|
540 |
norm_c = self.attn_norm_c(c, t)
|
|
|
548 |
# process attention output for context c
|
549 |
if self.context_pre_only:
|
550 |
c = None
|
551 |
+
else: # if not last layer
|
552 |
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
553 |
|
554 |
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
|
|
557 |
|
558 |
# process attention output for input x
|
559 |
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
560 |
+
|
561 |
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
562 |
x_ff_output = self.ff_x(norm_x)
|
563 |
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
|
|
567 |
|
568 |
# time step conditioning embedding
|
569 |
|
570 |
+
|
571 |
class TimestepEmbedding(nn.Module):
|
572 |
def __init__(self, dim, freq_embed_dim=256):
|
573 |
super().__init__()
|
574 |
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
575 |
+
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
|
|
|
|
|
|
|
|
576 |
|
577 |
+
def forward(self, timestep: float["b"]): # noqa: F821
|
578 |
time_hidden = self.time_embed(timestep)
|
579 |
time_hidden = time_hidden.to(timestep.dtype)
|
580 |
time = self.time_mlp(time_hidden) # b d
|
model/trainer.py
CHANGED
@@ -22,71 +22,69 @@ from model.dataset import DynamicBatchSampler, collate_fn
|
|
22 |
|
23 |
# trainer
|
24 |
|
|
|
25 |
class Trainer:
|
26 |
def __init__(
|
27 |
self,
|
28 |
model: CFM,
|
29 |
epochs,
|
30 |
learning_rate,
|
31 |
-
num_warmup_updates
|
32 |
-
save_per_updates
|
33 |
-
checkpoint_path
|
34 |
-
batch_size
|
35 |
batch_size_type: str = "sample",
|
36 |
-
max_samples
|
37 |
-
grad_accumulation_steps
|
38 |
-
max_grad_norm
|
39 |
noise_scheduler: str | None = None,
|
40 |
duration_predictor: torch.nn.Module | None = None,
|
41 |
-
wandb_project
|
42 |
-
wandb_run_name
|
43 |
wandb_resume_id: str = None,
|
44 |
-
last_per_steps
|
45 |
accelerate_kwargs: dict = dict(),
|
46 |
ema_kwargs: dict = dict(),
|
47 |
bnb_optimizer: bool = False,
|
48 |
):
|
49 |
-
|
50 |
-
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
51 |
|
52 |
logger = "wandb" if wandb.api.api_key else None
|
53 |
print(f"Using logger: {logger}")
|
54 |
|
55 |
self.accelerator = Accelerator(
|
56 |
-
log_with
|
57 |
-
kwargs_handlers
|
58 |
-
gradient_accumulation_steps
|
59 |
-
**accelerate_kwargs
|
60 |
)
|
61 |
|
62 |
if logger == "wandb":
|
63 |
if exists(wandb_resume_id):
|
64 |
-
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name,
|
65 |
else:
|
66 |
-
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
|
67 |
self.accelerator.init_trackers(
|
68 |
-
project_name
|
69 |
init_kwargs=init_kwargs,
|
70 |
-
config={
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
81 |
|
82 |
self.model = model
|
83 |
|
84 |
if self.is_main:
|
85 |
-
self.ema_model = EMA(
|
86 |
-
model,
|
87 |
-
include_online_model = False,
|
88 |
-
**ema_kwargs
|
89 |
-
)
|
90 |
|
91 |
self.ema_model.to(self.accelerator.device)
|
92 |
|
@@ -94,7 +92,7 @@ class Trainer:
|
|
94 |
self.num_warmup_updates = num_warmup_updates
|
95 |
self.save_per_updates = save_per_updates
|
96 |
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
|
97 |
-
self.checkpoint_path = default(checkpoint_path,
|
98 |
|
99 |
self.batch_size = batch_size
|
100 |
self.batch_size_type = batch_size_type
|
@@ -108,12 +106,11 @@ class Trainer:
|
|
108 |
|
109 |
if bnb_optimizer:
|
110 |
import bitsandbytes as bnb
|
|
|
111 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
112 |
else:
|
113 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
114 |
-
self.model, self.optimizer = self.accelerator.prepare(
|
115 |
-
self.model, self.optimizer
|
116 |
-
)
|
117 |
|
118 |
@property
|
119 |
def is_main(self):
|
@@ -123,81 +120,112 @@ class Trainer:
|
|
123 |
self.accelerator.wait_for_everyone()
|
124 |
if self.is_main:
|
125 |
checkpoint = dict(
|
126 |
-
model_state_dict
|
127 |
-
optimizer_state_dict
|
128 |
-
ema_model_state_dict
|
129 |
-
scheduler_state_dict
|
130 |
-
step
|
131 |
)
|
132 |
if not os.path.exists(self.checkpoint_path):
|
133 |
os.makedirs(self.checkpoint_path)
|
134 |
-
if last
|
135 |
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
|
136 |
print(f"Saved last checkpoint at step {step}")
|
137 |
else:
|
138 |
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
|
139 |
|
140 |
def load_checkpoint(self):
|
141 |
-
if
|
|
|
|
|
|
|
|
|
142 |
return 0
|
143 |
-
|
144 |
self.accelerator.wait_for_everyone()
|
145 |
if "model_last.pt" in os.listdir(self.checkpoint_path):
|
146 |
latest_checkpoint = "model_last.pt"
|
147 |
else:
|
148 |
-
latest_checkpoint = sorted(
|
|
|
|
|
|
|
149 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
150 |
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
151 |
|
152 |
if self.is_main:
|
153 |
-
self.ema_model.load_state_dict(checkpoint[
|
154 |
|
155 |
-
if
|
156 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint[
|
157 |
-
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint[
|
158 |
if self.scheduler:
|
159 |
-
self.scheduler.load_state_dict(checkpoint[
|
160 |
-
step = checkpoint[
|
161 |
else:
|
162 |
-
checkpoint[
|
163 |
-
|
|
|
|
|
|
|
|
|
164 |
step = 0
|
165 |
|
166 |
-
del checkpoint
|
|
|
167 |
return step
|
168 |
|
169 |
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
170 |
-
|
171 |
if exists(resumable_with_seed):
|
172 |
generator = torch.Generator()
|
173 |
generator.manual_seed(resumable_with_seed)
|
174 |
-
else:
|
175 |
generator = None
|
176 |
|
177 |
if self.batch_size_type == "sample":
|
178 |
-
train_dataloader = DataLoader(
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
elif self.batch_size_type == "frame":
|
181 |
self.accelerator.even_batches = False
|
182 |
sampler = SequentialSampler(train_dataset)
|
183 |
-
batch_sampler = DynamicBatchSampler(
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
else:
|
187 |
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
|
188 |
-
|
189 |
# accelerator.prepare() dispatches batches to devices;
|
190 |
# which means the length of dataloader calculated before, should consider the number of devices
|
191 |
-
warmup_steps =
|
192 |
-
|
|
|
|
|
193 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
194 |
decay_steps = total_steps - warmup_steps
|
195 |
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
|
196 |
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
|
197 |
-
self.scheduler = SequentialLR(
|
198 |
-
|
199 |
-
|
200 |
-
train_dataloader, self.scheduler = self.accelerator.prepare(
|
|
|
|
|
201 |
start_step = self.load_checkpoint()
|
202 |
global_step = start_step
|
203 |
|
@@ -212,23 +240,36 @@ class Trainer:
|
|
212 |
for epoch in range(skipped_epoch, self.epochs):
|
213 |
self.model.train()
|
214 |
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
215 |
-
progress_bar = tqdm(
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
else:
|
218 |
-
progress_bar = tqdm(
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
for batch in progress_bar:
|
221 |
with self.accelerator.accumulate(self.model):
|
222 |
-
text_inputs = batch[
|
223 |
-
mel_spec = batch[
|
224 |
mel_lengths = batch["mel_lengths"]
|
225 |
|
226 |
# TODO. add duration predictor training
|
227 |
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
|
228 |
-
dur_loss = self.duration_predictor(mel_spec, lens=batch.get(
|
229 |
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
|
230 |
|
231 |
-
loss, cond, pred = self.model(
|
|
|
|
|
232 |
self.accelerator.backward(loss)
|
233 |
|
234 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
@@ -245,13 +286,13 @@ class Trainer:
|
|
245 |
|
246 |
if self.accelerator.is_local_main_process:
|
247 |
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
|
248 |
-
|
249 |
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
250 |
-
|
251 |
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
252 |
self.save_checkpoint(global_step)
|
253 |
-
|
254 |
if global_step % self.last_per_steps == 0:
|
255 |
self.save_checkpoint(global_step, last=True)
|
256 |
-
|
257 |
self.accelerator.end_training()
|
|
|
22 |
|
23 |
# trainer
|
24 |
|
25 |
+
|
26 |
class Trainer:
|
27 |
def __init__(
|
28 |
self,
|
29 |
model: CFM,
|
30 |
epochs,
|
31 |
learning_rate,
|
32 |
+
num_warmup_updates=20000,
|
33 |
+
save_per_updates=1000,
|
34 |
+
checkpoint_path=None,
|
35 |
+
batch_size=32,
|
36 |
batch_size_type: str = "sample",
|
37 |
+
max_samples=32,
|
38 |
+
grad_accumulation_steps=1,
|
39 |
+
max_grad_norm=1.0,
|
40 |
noise_scheduler: str | None = None,
|
41 |
duration_predictor: torch.nn.Module | None = None,
|
42 |
+
wandb_project="test_e2-tts",
|
43 |
+
wandb_run_name="test_run",
|
44 |
wandb_resume_id: str = None,
|
45 |
+
last_per_steps=None,
|
46 |
accelerate_kwargs: dict = dict(),
|
47 |
ema_kwargs: dict = dict(),
|
48 |
bnb_optimizer: bool = False,
|
49 |
):
|
50 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
|
|
51 |
|
52 |
logger = "wandb" if wandb.api.api_key else None
|
53 |
print(f"Using logger: {logger}")
|
54 |
|
55 |
self.accelerator = Accelerator(
|
56 |
+
log_with=logger,
|
57 |
+
kwargs_handlers=[ddp_kwargs],
|
58 |
+
gradient_accumulation_steps=grad_accumulation_steps,
|
59 |
+
**accelerate_kwargs,
|
60 |
)
|
61 |
|
62 |
if logger == "wandb":
|
63 |
if exists(wandb_resume_id):
|
64 |
+
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
|
65 |
else:
|
66 |
+
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
67 |
self.accelerator.init_trackers(
|
68 |
+
project_name=wandb_project,
|
69 |
init_kwargs=init_kwargs,
|
70 |
+
config={
|
71 |
+
"epochs": epochs,
|
72 |
+
"learning_rate": learning_rate,
|
73 |
+
"num_warmup_updates": num_warmup_updates,
|
74 |
+
"batch_size": batch_size,
|
75 |
+
"batch_size_type": batch_size_type,
|
76 |
+
"max_samples": max_samples,
|
77 |
+
"grad_accumulation_steps": grad_accumulation_steps,
|
78 |
+
"max_grad_norm": max_grad_norm,
|
79 |
+
"gpus": self.accelerator.num_processes,
|
80 |
+
"noise_scheduler": noise_scheduler,
|
81 |
+
},
|
82 |
+
)
|
83 |
|
84 |
self.model = model
|
85 |
|
86 |
if self.is_main:
|
87 |
+
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
|
|
|
|
|
|
|
|
|
88 |
|
89 |
self.ema_model.to(self.accelerator.device)
|
90 |
|
|
|
92 |
self.num_warmup_updates = num_warmup_updates
|
93 |
self.save_per_updates = save_per_updates
|
94 |
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
|
95 |
+
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
96 |
|
97 |
self.batch_size = batch_size
|
98 |
self.batch_size_type = batch_size_type
|
|
|
106 |
|
107 |
if bnb_optimizer:
|
108 |
import bitsandbytes as bnb
|
109 |
+
|
110 |
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
111 |
else:
|
112 |
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
113 |
+
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
|
|
|
|
|
114 |
|
115 |
@property
|
116 |
def is_main(self):
|
|
|
120 |
self.accelerator.wait_for_everyone()
|
121 |
if self.is_main:
|
122 |
checkpoint = dict(
|
123 |
+
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
124 |
+
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
125 |
+
ema_model_state_dict=self.ema_model.state_dict(),
|
126 |
+
scheduler_state_dict=self.scheduler.state_dict(),
|
127 |
+
step=step,
|
128 |
)
|
129 |
if not os.path.exists(self.checkpoint_path):
|
130 |
os.makedirs(self.checkpoint_path)
|
131 |
+
if last:
|
132 |
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
|
133 |
print(f"Saved last checkpoint at step {step}")
|
134 |
else:
|
135 |
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
|
136 |
|
137 |
def load_checkpoint(self):
|
138 |
+
if (
|
139 |
+
not exists(self.checkpoint_path)
|
140 |
+
or not os.path.exists(self.checkpoint_path)
|
141 |
+
or not os.listdir(self.checkpoint_path)
|
142 |
+
):
|
143 |
return 0
|
144 |
+
|
145 |
self.accelerator.wait_for_everyone()
|
146 |
if "model_last.pt" in os.listdir(self.checkpoint_path):
|
147 |
latest_checkpoint = "model_last.pt"
|
148 |
else:
|
149 |
+
latest_checkpoint = sorted(
|
150 |
+
[f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
|
151 |
+
key=lambda x: int("".join(filter(str.isdigit, x))),
|
152 |
+
)[-1]
|
153 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
154 |
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
155 |
|
156 |
if self.is_main:
|
157 |
+
self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
|
158 |
|
159 |
+
if "step" in checkpoint:
|
160 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
|
161 |
+
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
|
162 |
if self.scheduler:
|
163 |
+
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
164 |
+
step = checkpoint["step"]
|
165 |
else:
|
166 |
+
checkpoint["model_state_dict"] = {
|
167 |
+
k.replace("ema_model.", ""): v
|
168 |
+
for k, v in checkpoint["ema_model_state_dict"].items()
|
169 |
+
if k not in ["initted", "step"]
|
170 |
+
}
|
171 |
+
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
|
172 |
step = 0
|
173 |
|
174 |
+
del checkpoint
|
175 |
+
gc.collect()
|
176 |
return step
|
177 |
|
178 |
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
|
|
179 |
if exists(resumable_with_seed):
|
180 |
generator = torch.Generator()
|
181 |
generator.manual_seed(resumable_with_seed)
|
182 |
+
else:
|
183 |
generator = None
|
184 |
|
185 |
if self.batch_size_type == "sample":
|
186 |
+
train_dataloader = DataLoader(
|
187 |
+
train_dataset,
|
188 |
+
collate_fn=collate_fn,
|
189 |
+
num_workers=num_workers,
|
190 |
+
pin_memory=True,
|
191 |
+
persistent_workers=True,
|
192 |
+
batch_size=self.batch_size,
|
193 |
+
shuffle=True,
|
194 |
+
generator=generator,
|
195 |
+
)
|
196 |
elif self.batch_size_type == "frame":
|
197 |
self.accelerator.even_batches = False
|
198 |
sampler = SequentialSampler(train_dataset)
|
199 |
+
batch_sampler = DynamicBatchSampler(
|
200 |
+
sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
|
201 |
+
)
|
202 |
+
train_dataloader = DataLoader(
|
203 |
+
train_dataset,
|
204 |
+
collate_fn=collate_fn,
|
205 |
+
num_workers=num_workers,
|
206 |
+
pin_memory=True,
|
207 |
+
persistent_workers=True,
|
208 |
+
batch_sampler=batch_sampler,
|
209 |
+
)
|
210 |
else:
|
211 |
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
|
212 |
+
|
213 |
# accelerator.prepare() dispatches batches to devices;
|
214 |
# which means the length of dataloader calculated before, should consider the number of devices
|
215 |
+
warmup_steps = (
|
216 |
+
self.num_warmup_updates * self.accelerator.num_processes
|
217 |
+
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
218 |
+
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
219 |
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
220 |
decay_steps = total_steps - warmup_steps
|
221 |
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
|
222 |
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
|
223 |
+
self.scheduler = SequentialLR(
|
224 |
+
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
|
225 |
+
)
|
226 |
+
train_dataloader, self.scheduler = self.accelerator.prepare(
|
227 |
+
train_dataloader, self.scheduler
|
228 |
+
) # actual steps = 1 gpu steps / gpus
|
229 |
start_step = self.load_checkpoint()
|
230 |
global_step = start_step
|
231 |
|
|
|
240 |
for epoch in range(skipped_epoch, self.epochs):
|
241 |
self.model.train()
|
242 |
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
243 |
+
progress_bar = tqdm(
|
244 |
+
skipped_dataloader,
|
245 |
+
desc=f"Epoch {epoch+1}/{self.epochs}",
|
246 |
+
unit="step",
|
247 |
+
disable=not self.accelerator.is_local_main_process,
|
248 |
+
initial=skipped_batch,
|
249 |
+
total=orig_epoch_step,
|
250 |
+
)
|
251 |
else:
|
252 |
+
progress_bar = tqdm(
|
253 |
+
train_dataloader,
|
254 |
+
desc=f"Epoch {epoch+1}/{self.epochs}",
|
255 |
+
unit="step",
|
256 |
+
disable=not self.accelerator.is_local_main_process,
|
257 |
+
)
|
258 |
|
259 |
for batch in progress_bar:
|
260 |
with self.accelerator.accumulate(self.model):
|
261 |
+
text_inputs = batch["text"]
|
262 |
+
mel_spec = batch["mel"].permute(0, 2, 1)
|
263 |
mel_lengths = batch["mel_lengths"]
|
264 |
|
265 |
# TODO. add duration predictor training
|
266 |
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
|
267 |
+
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
|
268 |
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
|
269 |
|
270 |
+
loss, cond, pred = self.model(
|
271 |
+
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
|
272 |
+
)
|
273 |
self.accelerator.backward(loss)
|
274 |
|
275 |
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
|
|
286 |
|
287 |
if self.accelerator.is_local_main_process:
|
288 |
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
|
289 |
+
|
290 |
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
291 |
+
|
292 |
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
293 |
self.save_checkpoint(global_step)
|
294 |
+
|
295 |
if global_step % self.last_per_steps == 0:
|
296 |
self.save_checkpoint(global_step, last=True)
|
297 |
+
|
298 |
self.accelerator.end_training()
|
model/utils.py
CHANGED
@@ -8,6 +8,7 @@ from tqdm import tqdm
|
|
8 |
from collections import defaultdict
|
9 |
|
10 |
import matplotlib
|
|
|
11 |
matplotlib.use("Agg")
|
12 |
import matplotlib.pylab as plt
|
13 |
|
@@ -25,109 +26,102 @@ from model.modules import MelSpec
|
|
25 |
|
26 |
# seed everything
|
27 |
|
28 |
-
|
|
|
29 |
random.seed(seed)
|
30 |
-
os.environ[
|
31 |
torch.manual_seed(seed)
|
32 |
torch.cuda.manual_seed(seed)
|
33 |
torch.cuda.manual_seed_all(seed)
|
34 |
torch.backends.cudnn.deterministic = True
|
35 |
torch.backends.cudnn.benchmark = False
|
36 |
|
|
|
37 |
# helpers
|
38 |
|
|
|
39 |
def exists(v):
|
40 |
return v is not None
|
41 |
|
|
|
42 |
def default(v, d):
|
43 |
return v if exists(v) else d
|
44 |
|
|
|
45 |
# tensor helpers
|
46 |
|
47 |
-
def lens_to_mask(
|
48 |
-
t: int['b'],
|
49 |
-
length: int | None = None
|
50 |
-
) -> bool['b n']:
|
51 |
|
|
|
52 |
if not exists(length):
|
53 |
length = t.amax()
|
54 |
|
55 |
-
seq = torch.arange(length, device
|
56 |
return seq[None, :] < t[:, None]
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
):
|
63 |
-
max_seq_len = seq_len.max().item()
|
64 |
-
seq = torch.arange(max_seq_len, device = start.device).long()
|
65 |
start_mask = seq[None, :] >= start[:, None]
|
66 |
end_mask = seq[None, :] < end[:, None]
|
67 |
return start_mask & end_mask
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
frac_lengths: float['b']
|
72 |
-
):
|
73 |
lengths = (frac_lengths * seq_len).long()
|
74 |
max_start = seq_len - lengths
|
75 |
|
76 |
rand = torch.rand_like(frac_lengths)
|
77 |
-
start = (max_start * rand).long().clamp(min
|
78 |
end = start + lengths
|
79 |
|
80 |
return mask_from_start_end_indices(seq_len, start, end)
|
81 |
|
82 |
-
def maybe_masked_mean(
|
83 |
-
t: float['b n d'],
|
84 |
-
mask: bool['b n'] = None
|
85 |
-
) -> float['b d']:
|
86 |
|
|
|
87 |
if not exists(mask):
|
88 |
-
return t.mean(dim
|
89 |
|
90 |
-
t = torch.where(mask[:, :, None], t, torch.tensor(0
|
91 |
num = t.sum(dim=1)
|
92 |
den = mask.float().sum(dim=1)
|
93 |
|
94 |
-
return num / den.clamp(min=1.)
|
95 |
|
96 |
|
97 |
# simple utf-8 tokenizer, since paper went character based
|
98 |
-
def list_str_to_tensor(
|
99 |
-
|
100 |
-
padding_value =
|
101 |
-
) -> int['b nt']:
|
102 |
-
list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
|
103 |
-
text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
|
104 |
return text
|
105 |
|
|
|
106 |
# char tokenizer, based on custom dataset's extracted .txt file
|
107 |
def list_str_to_idx(
|
108 |
text: list[str] | list[list[str]],
|
109 |
vocab_char_map: dict[str, int], # {char: idx}
|
110 |
-
padding_value
|
111 |
-
) -> int[
|
112 |
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
113 |
-
text = pad_sequence(list_idx_tensors, padding_value
|
114 |
return text
|
115 |
|
116 |
|
117 |
# Get tokenizer
|
118 |
|
|
|
119 |
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
120 |
-
|
121 |
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
122 |
- "char" for char-wise tokenizer, need .txt vocab_file
|
123 |
- "byte" for utf-8 tokenizer
|
124 |
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
125 |
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
126 |
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
127 |
-
- if use "byte", set to 256 (unicode byte range)
|
128 |
-
|
129 |
if tokenizer in ["pinyin", "char"]:
|
130 |
-
with open
|
131 |
vocab_char_map = {}
|
132 |
for i, char in enumerate(f):
|
133 |
vocab_char_map[char[:-1]] = i
|
@@ -138,7 +132,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
|
138 |
vocab_char_map = None
|
139 |
vocab_size = 256
|
140 |
elif tokenizer == "custom":
|
141 |
-
with open
|
142 |
vocab_char_map = {}
|
143 |
for i, char in enumerate(f):
|
144 |
vocab_char_map[char[:-1]] = i
|
@@ -149,16 +143,19 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
|
149 |
|
150 |
# convert char to pinyin
|
151 |
|
152 |
-
|
|
|
153 |
final_text_list = []
|
154 |
-
god_knows_why_en_testset_contains_zh_quote = str.maketrans(
|
155 |
-
|
|
|
|
|
156 |
for text in text_list:
|
157 |
char_list = []
|
158 |
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
159 |
text = text.translate(custom_trans)
|
160 |
for seg in jieba.cut(text):
|
161 |
-
seg_byte_len = len(bytes(seg,
|
162 |
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
163 |
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
164 |
char_list.append(" ")
|
@@ -187,7 +184,7 @@ def convert_char_to_pinyin(text_list, polyphone = True):
|
|
187 |
# save spectrogram
|
188 |
def save_spectrogram(spectrogram, path):
|
189 |
plt.figure(figsize=(12, 4))
|
190 |
-
plt.imshow(spectrogram, origin=
|
191 |
plt.colorbar()
|
192 |
plt.savefig(path)
|
193 |
plt.close()
|
@@ -195,13 +192,15 @@ def save_spectrogram(spectrogram, path):
|
|
195 |
|
196 |
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
197 |
def get_seedtts_testset_metainfo(metalst):
|
198 |
-
f = open(metalst)
|
|
|
|
|
199 |
metainfo = []
|
200 |
for line in lines:
|
201 |
-
if len(line.strip().split(
|
202 |
-
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split(
|
203 |
-
elif len(line.strip().split(
|
204 |
-
utt, prompt_text, prompt_wav, gt_text = line.strip().split(
|
205 |
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
206 |
if not os.path.isabs(prompt_wav):
|
207 |
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
@@ -211,18 +210,20 @@ def get_seedtts_testset_metainfo(metalst):
|
|
211 |
|
212 |
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
|
213 |
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
|
214 |
-
f = open(metalst)
|
|
|
|
|
215 |
metainfo = []
|
216 |
for line in lines:
|
217 |
-
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split(
|
218 |
|
219 |
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
220 |
-
ref_spk_id, ref_chaptr_id, _ =
|
221 |
-
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt +
|
222 |
|
223 |
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
224 |
-
gen_spk_id, gen_chaptr_id, _ =
|
225 |
-
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt +
|
226 |
|
227 |
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
|
228 |
|
@@ -234,7 +235,7 @@ def padded_mel_batch(ref_mels):
|
|
234 |
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
235 |
padded_ref_mels = []
|
236 |
for mel in ref_mels:
|
237 |
-
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value
|
238 |
padded_ref_mels.append(padded_ref_mel)
|
239 |
padded_ref_mels = torch.stack(padded_ref_mels)
|
240 |
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
@@ -243,12 +244,21 @@ def padded_mel_batch(ref_mels):
|
|
243 |
|
244 |
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
245 |
|
|
|
246 |
def get_inference_prompt(
|
247 |
-
metainfo,
|
248 |
-
speed
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
):
|
253 |
prompts_all = []
|
254 |
|
@@ -256,13 +266,15 @@ def get_inference_prompt(
|
|
256 |
max_tokens = max_secs * target_sample_rate // hop_length
|
257 |
|
258 |
batch_accum = [0] * num_buckets
|
259 |
-
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list =
|
260 |
-
|
|
|
261 |
|
262 |
-
mel_spectrogram = MelSpec(
|
|
|
|
|
263 |
|
264 |
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
|
265 |
-
|
266 |
# Audio
|
267 |
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
268 |
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
@@ -274,11 +286,11 @@ def get_inference_prompt(
|
|
274 |
ref_audio = resampler(ref_audio)
|
275 |
|
276 |
# Text
|
277 |
-
if len(prompt_text[-1].encode(
|
278 |
prompt_text = prompt_text + " "
|
279 |
text = [prompt_text + gt_text]
|
280 |
if tokenizer == "pinyin":
|
281 |
-
text_list = convert_char_to_pinyin(text, polyphone
|
282 |
else:
|
283 |
text_list = text
|
284 |
|
@@ -294,8 +306,8 @@ def get_inference_prompt(
|
|
294 |
# # test vocoder resynthesis
|
295 |
# ref_audio = gt_audio
|
296 |
else:
|
297 |
-
ref_text_len = len(prompt_text.encode(
|
298 |
-
gen_text_len = len(gt_text.encode(
|
299 |
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
300 |
|
301 |
# to mel spectrogram
|
@@ -304,8 +316,9 @@ def get_inference_prompt(
|
|
304 |
|
305 |
# deal with batch
|
306 |
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
307 |
-
assert
|
308 |
-
|
|
|
309 |
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
310 |
|
311 |
utts[bucket_i].append(utt)
|
@@ -319,28 +332,39 @@ def get_inference_prompt(
|
|
319 |
|
320 |
if batch_accum[bucket_i] >= infer_batch_size:
|
321 |
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
322 |
-
prompts_all.append(
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
|
|
|
|
330 |
batch_accum[bucket_i] = 0
|
331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
|
333 |
# add residual
|
334 |
for bucket_i, bucket_frames in enumerate(batch_accum):
|
335 |
if bucket_frames > 0:
|
336 |
-
prompts_all.append(
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
|
|
|
|
344 |
# not only leave easy work for last workers
|
345 |
random.seed(666)
|
346 |
random.shuffle(prompts_all)
|
@@ -351,6 +375,7 @@ def get_inference_prompt(
|
|
351 |
# get wav_res_ref_text of seed-tts test metalst
|
352 |
# https://github.com/BytedanceSpeech/seed-tts-eval
|
353 |
|
|
|
354 |
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
355 |
f = open(metalst)
|
356 |
lines = f.readlines()
|
@@ -358,14 +383,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
|
358 |
|
359 |
test_set_ = []
|
360 |
for line in tqdm(lines):
|
361 |
-
if len(line.strip().split(
|
362 |
-
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split(
|
363 |
-
elif len(line.strip().split(
|
364 |
-
utt, prompt_text, prompt_wav, gt_text = line.strip().split(
|
365 |
|
366 |
-
if not os.path.exists(os.path.join(gen_wav_dir, utt +
|
367 |
continue
|
368 |
-
gen_wav = os.path.join(gen_wav_dir, utt +
|
369 |
if not os.path.isabs(prompt_wav):
|
370 |
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
371 |
|
@@ -374,65 +399,69 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
|
374 |
num_jobs = len(gpus)
|
375 |
if num_jobs == 1:
|
376 |
return [(gpus[0], test_set_)]
|
377 |
-
|
378 |
wav_per_job = len(test_set_) // num_jobs + 1
|
379 |
test_set = []
|
380 |
for i in range(num_jobs):
|
381 |
-
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
382 |
|
383 |
return test_set
|
384 |
|
385 |
|
386 |
# get librispeech test-clean cross sentence test
|
387 |
|
388 |
-
|
|
|
389 |
f = open(metalst)
|
390 |
lines = f.readlines()
|
391 |
f.close()
|
392 |
|
393 |
test_set_ = []
|
394 |
for line in tqdm(lines):
|
395 |
-
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split(
|
396 |
|
397 |
if eval_ground_truth:
|
398 |
-
gen_spk_id, gen_chaptr_id, _ =
|
399 |
-
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt +
|
400 |
else:
|
401 |
-
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt +
|
402 |
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
|
403 |
-
gen_wav = os.path.join(gen_wav_dir, gen_utt +
|
404 |
|
405 |
-
ref_spk_id, ref_chaptr_id, _ =
|
406 |
-
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt +
|
407 |
|
408 |
test_set_.append((gen_wav, ref_wav, gen_txt))
|
409 |
|
410 |
num_jobs = len(gpus)
|
411 |
if num_jobs == 1:
|
412 |
return [(gpus[0], test_set_)]
|
413 |
-
|
414 |
wav_per_job = len(test_set_) // num_jobs + 1
|
415 |
test_set = []
|
416 |
for i in range(num_jobs):
|
417 |
-
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
418 |
|
419 |
return test_set
|
420 |
|
421 |
|
422 |
# load asr model
|
423 |
|
424 |
-
|
|
|
425 |
if lang == "zh":
|
426 |
from funasr import AutoModel
|
|
|
427 |
model = AutoModel(
|
428 |
-
model
|
429 |
-
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
|
430 |
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
|
431 |
-
# spk_model = os.path.join(ckpt_dir, "cam++"),
|
432 |
disable_update=True,
|
433 |
-
|
434 |
elif lang == "en":
|
435 |
from faster_whisper import WhisperModel
|
|
|
436 |
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
437 |
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
438 |
return model
|
@@ -440,44 +469,50 @@ def load_asr_model(lang, ckpt_dir = ""):
|
|
440 |
|
441 |
# WER Evaluation, the way Seed-TTS does
|
442 |
|
|
|
443 |
def run_asr_wer(args):
|
444 |
rank, lang, test_set, ckpt_dir = args
|
445 |
|
446 |
if lang == "zh":
|
447 |
import zhconv
|
|
|
448 |
torch.cuda.set_device(rank)
|
449 |
elif lang == "en":
|
450 |
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
|
451 |
else:
|
452 |
-
raise NotImplementedError(
|
|
|
|
|
|
|
|
|
453 |
|
454 |
-
asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
|
455 |
-
|
456 |
from zhon.hanzi import punctuation
|
|
|
457 |
punctuation_all = punctuation + string.punctuation
|
458 |
wers = []
|
459 |
|
460 |
from jiwer import compute_measures
|
|
|
461 |
for gen_wav, prompt_wav, truth in tqdm(test_set):
|
462 |
if lang == "zh":
|
463 |
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
|
464 |
hypo = res[0]["text"]
|
465 |
-
hypo = zhconv.convert(hypo,
|
466 |
elif lang == "en":
|
467 |
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
|
468 |
-
hypo =
|
469 |
for segment in segments:
|
470 |
-
hypo = hypo +
|
471 |
|
472 |
# raw_truth = truth
|
473 |
# raw_hypo = hypo
|
474 |
|
475 |
for x in punctuation_all:
|
476 |
-
truth = truth.replace(x,
|
477 |
-
hypo = hypo.replace(x,
|
478 |
|
479 |
-
truth = truth.replace(
|
480 |
-
hypo = hypo.replace(
|
481 |
|
482 |
if lang == "zh":
|
483 |
truth = " ".join([x for x in truth])
|
@@ -501,22 +536,22 @@ def run_asr_wer(args):
|
|
501 |
|
502 |
# SIM Evaluation
|
503 |
|
|
|
504 |
def run_sim(args):
|
505 |
rank, test_set, ckpt_dir = args
|
506 |
device = f"cuda:{rank}"
|
507 |
|
508 |
-
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type=
|
509 |
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
|
510 |
-
model.load_state_dict(state_dict[
|
511 |
|
512 |
-
use_gpu=True if torch.cuda.is_available() else False
|
513 |
if use_gpu:
|
514 |
model = model.cuda(device)
|
515 |
model.eval()
|
516 |
|
517 |
sim_list = []
|
518 |
for wav1, wav2, truth in tqdm(test_set):
|
519 |
-
|
520 |
wav1, sr1 = torchaudio.load(wav1)
|
521 |
wav2, sr2 = torchaudio.load(wav2)
|
522 |
|
@@ -531,20 +566,21 @@ def run_sim(args):
|
|
531 |
with torch.no_grad():
|
532 |
emb1 = model(wav1)
|
533 |
emb2 = model(wav2)
|
534 |
-
|
535 |
sim = F.cosine_similarity(emb1, emb2)[0].item()
|
536 |
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
|
537 |
sim_list.append(sim)
|
538 |
-
|
539 |
return sim_list
|
540 |
|
541 |
|
542 |
# filter func for dirty data with many repetitions
|
543 |
|
544 |
-
|
|
|
545 |
pattern_count = defaultdict(int)
|
546 |
for i in range(len(text) - length + 1):
|
547 |
-
pattern = text[i:i + length]
|
548 |
pattern_count[pattern] += 1
|
549 |
for pattern, count in pattern_count.items():
|
550 |
if count > tolerance:
|
@@ -554,25 +590,31 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
554 |
|
555 |
# load model checkpoint for inference
|
556 |
|
557 |
-
|
|
|
558 |
if device == "cuda":
|
559 |
model = model.half()
|
560 |
|
561 |
ckpt_type = ckpt_path.split(".")[-1]
|
562 |
if ckpt_type == "safetensors":
|
563 |
from safetensors.torch import load_file
|
|
|
564 |
checkpoint = load_file(ckpt_path)
|
565 |
else:
|
566 |
checkpoint = torch.load(ckpt_path, weights_only=True)
|
567 |
|
568 |
if use_ema:
|
569 |
if ckpt_type == "safetensors":
|
570 |
-
checkpoint = {
|
571 |
-
checkpoint[
|
572 |
-
|
|
|
|
|
|
|
|
|
573 |
else:
|
574 |
if ckpt_type == "safetensors":
|
575 |
-
checkpoint = {
|
576 |
-
model.load_state_dict(checkpoint[
|
577 |
|
578 |
return model.to(device)
|
|
|
8 |
from collections import defaultdict
|
9 |
|
10 |
import matplotlib
|
11 |
+
|
12 |
matplotlib.use("Agg")
|
13 |
import matplotlib.pylab as plt
|
14 |
|
|
|
26 |
|
27 |
# seed everything
|
28 |
|
29 |
+
|
30 |
+
def seed_everything(seed=0):
|
31 |
random.seed(seed)
|
32 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
33 |
torch.manual_seed(seed)
|
34 |
torch.cuda.manual_seed(seed)
|
35 |
torch.cuda.manual_seed_all(seed)
|
36 |
torch.backends.cudnn.deterministic = True
|
37 |
torch.backends.cudnn.benchmark = False
|
38 |
|
39 |
+
|
40 |
# helpers
|
41 |
|
42 |
+
|
43 |
def exists(v):
|
44 |
return v is not None
|
45 |
|
46 |
+
|
47 |
def default(v, d):
|
48 |
return v if exists(v) else d
|
49 |
|
50 |
+
|
51 |
# tensor helpers
|
52 |
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
|
55 |
if not exists(length):
|
56 |
length = t.amax()
|
57 |
|
58 |
+
seq = torch.arange(length, device=t.device)
|
59 |
return seq[None, :] < t[:, None]
|
60 |
|
61 |
+
|
62 |
+
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
|
63 |
+
max_seq_len = seq_len.max().item()
|
64 |
+
seq = torch.arange(max_seq_len, device=start.device).long()
|
|
|
|
|
|
|
65 |
start_mask = seq[None, :] >= start[:, None]
|
66 |
end_mask = seq[None, :] < end[:, None]
|
67 |
return start_mask & end_mask
|
68 |
|
69 |
+
|
70 |
+
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
|
|
|
|
|
71 |
lengths = (frac_lengths * seq_len).long()
|
72 |
max_start = seq_len - lengths
|
73 |
|
74 |
rand = torch.rand_like(frac_lengths)
|
75 |
+
start = (max_start * rand).long().clamp(min=0)
|
76 |
end = start + lengths
|
77 |
|
78 |
return mask_from_start_end_indices(seq_len, start, end)
|
79 |
|
|
|
|
|
|
|
|
|
80 |
|
81 |
+
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
|
82 |
if not exists(mask):
|
83 |
+
return t.mean(dim=1)
|
84 |
|
85 |
+
t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
|
86 |
num = t.sum(dim=1)
|
87 |
den = mask.float().sum(dim=1)
|
88 |
|
89 |
+
return num / den.clamp(min=1.0)
|
90 |
|
91 |
|
92 |
# simple utf-8 tokenizer, since paper went character based
|
93 |
+
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
|
94 |
+
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
|
95 |
+
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
|
|
|
|
|
|
|
96 |
return text
|
97 |
|
98 |
+
|
99 |
# char tokenizer, based on custom dataset's extracted .txt file
|
100 |
def list_str_to_idx(
|
101 |
text: list[str] | list[list[str]],
|
102 |
vocab_char_map: dict[str, int], # {char: idx}
|
103 |
+
padding_value=-1,
|
104 |
+
) -> int["b nt"]: # noqa: F722
|
105 |
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
106 |
+
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
107 |
return text
|
108 |
|
109 |
|
110 |
# Get tokenizer
|
111 |
|
112 |
+
|
113 |
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
114 |
+
"""
|
115 |
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
116 |
- "char" for char-wise tokenizer, need .txt vocab_file
|
117 |
- "byte" for utf-8 tokenizer
|
118 |
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
119 |
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
120 |
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
121 |
+
- if use "byte", set to 256 (unicode byte range)
|
122 |
+
"""
|
123 |
if tokenizer in ["pinyin", "char"]:
|
124 |
+
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
|
125 |
vocab_char_map = {}
|
126 |
for i, char in enumerate(f):
|
127 |
vocab_char_map[char[:-1]] = i
|
|
|
132 |
vocab_char_map = None
|
133 |
vocab_size = 256
|
134 |
elif tokenizer == "custom":
|
135 |
+
with open(dataset_name, "r", encoding="utf-8") as f:
|
136 |
vocab_char_map = {}
|
137 |
for i, char in enumerate(f):
|
138 |
vocab_char_map[char[:-1]] = i
|
|
|
143 |
|
144 |
# convert char to pinyin
|
145 |
|
146 |
+
|
147 |
+
def convert_char_to_pinyin(text_list, polyphone=True):
|
148 |
final_text_list = []
|
149 |
+
god_knows_why_en_testset_contains_zh_quote = str.maketrans(
|
150 |
+
{"“": '"', "”": '"', "‘": "'", "’": "'"}
|
151 |
+
) # in case librispeech (orig no-pc) test-clean
|
152 |
+
custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
|
153 |
for text in text_list:
|
154 |
char_list = []
|
155 |
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
156 |
text = text.translate(custom_trans)
|
157 |
for seg in jieba.cut(text):
|
158 |
+
seg_byte_len = len(bytes(seg, "UTF-8"))
|
159 |
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
160 |
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
161 |
char_list.append(" ")
|
|
|
184 |
# save spectrogram
|
185 |
def save_spectrogram(spectrogram, path):
|
186 |
plt.figure(figsize=(12, 4))
|
187 |
+
plt.imshow(spectrogram, origin="lower", aspect="auto")
|
188 |
plt.colorbar()
|
189 |
plt.savefig(path)
|
190 |
plt.close()
|
|
|
192 |
|
193 |
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
194 |
def get_seedtts_testset_metainfo(metalst):
|
195 |
+
f = open(metalst)
|
196 |
+
lines = f.readlines()
|
197 |
+
f.close()
|
198 |
metainfo = []
|
199 |
for line in lines:
|
200 |
+
if len(line.strip().split("|")) == 5:
|
201 |
+
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
202 |
+
elif len(line.strip().split("|")) == 4:
|
203 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
204 |
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
205 |
if not os.path.isabs(prompt_wav):
|
206 |
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
|
|
210 |
|
211 |
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
|
212 |
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
|
213 |
+
f = open(metalst)
|
214 |
+
lines = f.readlines()
|
215 |
+
f.close()
|
216 |
metainfo = []
|
217 |
for line in lines:
|
218 |
+
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
|
219 |
|
220 |
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
221 |
+
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
|
222 |
+
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
|
223 |
|
224 |
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
225 |
+
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
|
226 |
+
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
|
227 |
|
228 |
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
|
229 |
|
|
|
235 |
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
236 |
padded_ref_mels = []
|
237 |
for mel in ref_mels:
|
238 |
+
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
|
239 |
padded_ref_mels.append(padded_ref_mel)
|
240 |
padded_ref_mels = torch.stack(padded_ref_mels)
|
241 |
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
|
|
244 |
|
245 |
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
246 |
|
247 |
+
|
248 |
def get_inference_prompt(
|
249 |
+
metainfo,
|
250 |
+
speed=1.0,
|
251 |
+
tokenizer="pinyin",
|
252 |
+
polyphone=True,
|
253 |
+
target_sample_rate=24000,
|
254 |
+
n_mel_channels=100,
|
255 |
+
hop_length=256,
|
256 |
+
target_rms=0.1,
|
257 |
+
use_truth_duration=False,
|
258 |
+
infer_batch_size=1,
|
259 |
+
num_buckets=200,
|
260 |
+
min_secs=3,
|
261 |
+
max_secs=40,
|
262 |
):
|
263 |
prompts_all = []
|
264 |
|
|
|
266 |
max_tokens = max_secs * target_sample_rate // hop_length
|
267 |
|
268 |
batch_accum = [0] * num_buckets
|
269 |
+
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
270 |
+
[[] for _ in range(num_buckets)] for _ in range(6)
|
271 |
+
)
|
272 |
|
273 |
+
mel_spectrogram = MelSpec(
|
274 |
+
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
|
275 |
+
)
|
276 |
|
277 |
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
|
|
|
278 |
# Audio
|
279 |
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
280 |
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
|
|
286 |
ref_audio = resampler(ref_audio)
|
287 |
|
288 |
# Text
|
289 |
+
if len(prompt_text[-1].encode("utf-8")) == 1:
|
290 |
prompt_text = prompt_text + " "
|
291 |
text = [prompt_text + gt_text]
|
292 |
if tokenizer == "pinyin":
|
293 |
+
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
294 |
else:
|
295 |
text_list = text
|
296 |
|
|
|
306 |
# # test vocoder resynthesis
|
307 |
# ref_audio = gt_audio
|
308 |
else:
|
309 |
+
ref_text_len = len(prompt_text.encode("utf-8"))
|
310 |
+
gen_text_len = len(gt_text.encode("utf-8"))
|
311 |
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
312 |
|
313 |
# to mel spectrogram
|
|
|
316 |
|
317 |
# deal with batch
|
318 |
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
319 |
+
assert (
|
320 |
+
min_tokens <= total_mel_len <= max_tokens
|
321 |
+
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
322 |
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
323 |
|
324 |
utts[bucket_i].append(utt)
|
|
|
332 |
|
333 |
if batch_accum[bucket_i] >= infer_batch_size:
|
334 |
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
335 |
+
prompts_all.append(
|
336 |
+
(
|
337 |
+
utts[bucket_i],
|
338 |
+
ref_rms_list[bucket_i],
|
339 |
+
padded_mel_batch(ref_mels[bucket_i]),
|
340 |
+
ref_mel_lens[bucket_i],
|
341 |
+
total_mel_lens[bucket_i],
|
342 |
+
final_text_list[bucket_i],
|
343 |
+
)
|
344 |
+
)
|
345 |
batch_accum[bucket_i] = 0
|
346 |
+
(
|
347 |
+
utts[bucket_i],
|
348 |
+
ref_rms_list[bucket_i],
|
349 |
+
ref_mels[bucket_i],
|
350 |
+
ref_mel_lens[bucket_i],
|
351 |
+
total_mel_lens[bucket_i],
|
352 |
+
final_text_list[bucket_i],
|
353 |
+
) = [], [], [], [], [], []
|
354 |
|
355 |
# add residual
|
356 |
for bucket_i, bucket_frames in enumerate(batch_accum):
|
357 |
if bucket_frames > 0:
|
358 |
+
prompts_all.append(
|
359 |
+
(
|
360 |
+
utts[bucket_i],
|
361 |
+
ref_rms_list[bucket_i],
|
362 |
+
padded_mel_batch(ref_mels[bucket_i]),
|
363 |
+
ref_mel_lens[bucket_i],
|
364 |
+
total_mel_lens[bucket_i],
|
365 |
+
final_text_list[bucket_i],
|
366 |
+
)
|
367 |
+
)
|
368 |
# not only leave easy work for last workers
|
369 |
random.seed(666)
|
370 |
random.shuffle(prompts_all)
|
|
|
375 |
# get wav_res_ref_text of seed-tts test metalst
|
376 |
# https://github.com/BytedanceSpeech/seed-tts-eval
|
377 |
|
378 |
+
|
379 |
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
380 |
f = open(metalst)
|
381 |
lines = f.readlines()
|
|
|
383 |
|
384 |
test_set_ = []
|
385 |
for line in tqdm(lines):
|
386 |
+
if len(line.strip().split("|")) == 5:
|
387 |
+
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
388 |
+
elif len(line.strip().split("|")) == 4:
|
389 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
390 |
|
391 |
+
if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
|
392 |
continue
|
393 |
+
gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
|
394 |
if not os.path.isabs(prompt_wav):
|
395 |
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
396 |
|
|
|
399 |
num_jobs = len(gpus)
|
400 |
if num_jobs == 1:
|
401 |
return [(gpus[0], test_set_)]
|
402 |
+
|
403 |
wav_per_job = len(test_set_) // num_jobs + 1
|
404 |
test_set = []
|
405 |
for i in range(num_jobs):
|
406 |
+
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
|
407 |
|
408 |
return test_set
|
409 |
|
410 |
|
411 |
# get librispeech test-clean cross sentence test
|
412 |
|
413 |
+
|
414 |
+
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
|
415 |
f = open(metalst)
|
416 |
lines = f.readlines()
|
417 |
f.close()
|
418 |
|
419 |
test_set_ = []
|
420 |
for line in tqdm(lines):
|
421 |
+
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
|
422 |
|
423 |
if eval_ground_truth:
|
424 |
+
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
|
425 |
+
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
|
426 |
else:
|
427 |
+
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
|
428 |
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
|
429 |
+
gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
|
430 |
|
431 |
+
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
|
432 |
+
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
|
433 |
|
434 |
test_set_.append((gen_wav, ref_wav, gen_txt))
|
435 |
|
436 |
num_jobs = len(gpus)
|
437 |
if num_jobs == 1:
|
438 |
return [(gpus[0], test_set_)]
|
439 |
+
|
440 |
wav_per_job = len(test_set_) // num_jobs + 1
|
441 |
test_set = []
|
442 |
for i in range(num_jobs):
|
443 |
+
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
|
444 |
|
445 |
return test_set
|
446 |
|
447 |
|
448 |
# load asr model
|
449 |
|
450 |
+
|
451 |
+
def load_asr_model(lang, ckpt_dir=""):
|
452 |
if lang == "zh":
|
453 |
from funasr import AutoModel
|
454 |
+
|
455 |
model = AutoModel(
|
456 |
+
model=os.path.join(ckpt_dir, "paraformer-zh"),
|
457 |
+
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
|
458 |
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
|
459 |
+
# spk_model = os.path.join(ckpt_dir, "cam++"),
|
460 |
disable_update=True,
|
461 |
+
) # following seed-tts setting
|
462 |
elif lang == "en":
|
463 |
from faster_whisper import WhisperModel
|
464 |
+
|
465 |
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
466 |
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
467 |
return model
|
|
|
469 |
|
470 |
# WER Evaluation, the way Seed-TTS does
|
471 |
|
472 |
+
|
473 |
def run_asr_wer(args):
|
474 |
rank, lang, test_set, ckpt_dir = args
|
475 |
|
476 |
if lang == "zh":
|
477 |
import zhconv
|
478 |
+
|
479 |
torch.cuda.set_device(rank)
|
480 |
elif lang == "en":
|
481 |
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
|
482 |
else:
|
483 |
+
raise NotImplementedError(
|
484 |
+
"lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
|
485 |
+
)
|
486 |
+
|
487 |
+
asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
|
488 |
|
|
|
|
|
489 |
from zhon.hanzi import punctuation
|
490 |
+
|
491 |
punctuation_all = punctuation + string.punctuation
|
492 |
wers = []
|
493 |
|
494 |
from jiwer import compute_measures
|
495 |
+
|
496 |
for gen_wav, prompt_wav, truth in tqdm(test_set):
|
497 |
if lang == "zh":
|
498 |
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
|
499 |
hypo = res[0]["text"]
|
500 |
+
hypo = zhconv.convert(hypo, "zh-cn")
|
501 |
elif lang == "en":
|
502 |
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
|
503 |
+
hypo = ""
|
504 |
for segment in segments:
|
505 |
+
hypo = hypo + " " + segment.text
|
506 |
|
507 |
# raw_truth = truth
|
508 |
# raw_hypo = hypo
|
509 |
|
510 |
for x in punctuation_all:
|
511 |
+
truth = truth.replace(x, "")
|
512 |
+
hypo = hypo.replace(x, "")
|
513 |
|
514 |
+
truth = truth.replace(" ", " ")
|
515 |
+
hypo = hypo.replace(" ", " ")
|
516 |
|
517 |
if lang == "zh":
|
518 |
truth = " ".join([x for x in truth])
|
|
|
536 |
|
537 |
# SIM Evaluation
|
538 |
|
539 |
+
|
540 |
def run_sim(args):
|
541 |
rank, test_set, ckpt_dir = args
|
542 |
device = f"cuda:{rank}"
|
543 |
|
544 |
+
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
|
545 |
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
|
546 |
+
model.load_state_dict(state_dict["model"], strict=False)
|
547 |
|
548 |
+
use_gpu = True if torch.cuda.is_available() else False
|
549 |
if use_gpu:
|
550 |
model = model.cuda(device)
|
551 |
model.eval()
|
552 |
|
553 |
sim_list = []
|
554 |
for wav1, wav2, truth in tqdm(test_set):
|
|
|
555 |
wav1, sr1 = torchaudio.load(wav1)
|
556 |
wav2, sr2 = torchaudio.load(wav2)
|
557 |
|
|
|
566 |
with torch.no_grad():
|
567 |
emb1 = model(wav1)
|
568 |
emb2 = model(wav2)
|
569 |
+
|
570 |
sim = F.cosine_similarity(emb1, emb2)[0].item()
|
571 |
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
|
572 |
sim_list.append(sim)
|
573 |
+
|
574 |
return sim_list
|
575 |
|
576 |
|
577 |
# filter func for dirty data with many repetitions
|
578 |
|
579 |
+
|
580 |
+
def repetition_found(text, length=2, tolerance=10):
|
581 |
pattern_count = defaultdict(int)
|
582 |
for i in range(len(text) - length + 1):
|
583 |
+
pattern = text[i : i + length]
|
584 |
pattern_count[pattern] += 1
|
585 |
for pattern, count in pattern_count.items():
|
586 |
if count > tolerance:
|
|
|
590 |
|
591 |
# load model checkpoint for inference
|
592 |
|
593 |
+
|
594 |
+
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
595 |
if device == "cuda":
|
596 |
model = model.half()
|
597 |
|
598 |
ckpt_type = ckpt_path.split(".")[-1]
|
599 |
if ckpt_type == "safetensors":
|
600 |
from safetensors.torch import load_file
|
601 |
+
|
602 |
checkpoint = load_file(ckpt_path)
|
603 |
else:
|
604 |
checkpoint = torch.load(ckpt_path, weights_only=True)
|
605 |
|
606 |
if use_ema:
|
607 |
if ckpt_type == "safetensors":
|
608 |
+
checkpoint = {"ema_model_state_dict": checkpoint}
|
609 |
+
checkpoint["model_state_dict"] = {
|
610 |
+
k.replace("ema_model.", ""): v
|
611 |
+
for k, v in checkpoint["ema_model_state_dict"].items()
|
612 |
+
if k not in ["initted", "step"]
|
613 |
+
}
|
614 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
615 |
else:
|
616 |
if ckpt_type == "safetensors":
|
617 |
+
checkpoint = {"model_state_dict": checkpoint}
|
618 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
619 |
|
620 |
return model.to(device)
|
model/utils_infer.py
CHANGED
@@ -19,11 +19,7 @@ from model.utils import (
|
|
19 |
convert_char_to_pinyin,
|
20 |
)
|
21 |
|
22 |
-
device = (
|
23 |
-
"cuda"
|
24 |
-
if torch.cuda.is_available()
|
25 |
-
else "mps" if torch.backends.mps.is_available() else "cpu"
|
26 |
-
)
|
27 |
print(f"Using {device} device")
|
28 |
|
29 |
asr_pipe = pipeline(
|
@@ -54,6 +50,7 @@ fix_duration = None
|
|
54 |
|
55 |
# chunk text into smaller pieces
|
56 |
|
|
|
57 |
def chunk_text(text, max_chars=135):
|
58 |
"""
|
59 |
Splits the input text into chunks, each with a maximum number of characters.
|
@@ -68,15 +65,15 @@ def chunk_text(text, max_chars=135):
|
|
68 |
chunks = []
|
69 |
current_chunk = ""
|
70 |
# Split the text into sentences based on punctuation followed by whitespace
|
71 |
-
sentences = re.split(r
|
72 |
|
73 |
for sentence in sentences:
|
74 |
-
if len(current_chunk.encode(
|
75 |
-
current_chunk += sentence + " " if sentence and len(sentence[-1].encode(
|
76 |
else:
|
77 |
if current_chunk:
|
78 |
chunks.append(current_chunk.strip())
|
79 |
-
current_chunk = sentence + " " if sentence and len(sentence[-1].encode(
|
80 |
|
81 |
if current_chunk:
|
82 |
chunks.append(current_chunk.strip())
|
@@ -86,6 +83,7 @@ def chunk_text(text, max_chars=135):
|
|
86 |
|
87 |
# load vocoder
|
88 |
|
|
|
89 |
def load_vocoder(is_local=False, local_path=""):
|
90 |
if is_local:
|
91 |
print(f"Load vocos from local path {local_path}")
|
@@ -101,23 +99,21 @@ def load_vocoder(is_local=False, local_path=""):
|
|
101 |
|
102 |
# load model for inference
|
103 |
|
|
|
104 |
def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
|
105 |
-
|
106 |
if vocab_file == "":
|
107 |
vocab_file = "Emilia_ZH_EN"
|
108 |
tokenizer = "pinyin"
|
109 |
else:
|
110 |
tokenizer = "custom"
|
111 |
|
112 |
-
print("\nvocab : ", vocab_file, tokenizer)
|
113 |
-
print("tokenizer : ", tokenizer)
|
114 |
-
print("model : ", ckpt_path,"\n")
|
115 |
|
116 |
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
117 |
model = CFM(
|
118 |
-
transformer=model_cls(
|
119 |
-
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
120 |
-
),
|
121 |
mel_spec_kwargs=dict(
|
122 |
target_sample_rate=target_sample_rate,
|
123 |
n_mel_channels=n_mel_channels,
|
@@ -129,21 +125,20 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
|
|
129 |
vocab_char_map=vocab_char_map,
|
130 |
).to(device)
|
131 |
|
132 |
-
model = load_checkpoint(model, ckpt_path, device, use_ema
|
133 |
|
134 |
return model
|
135 |
|
136 |
|
137 |
# preprocess reference audio and text
|
138 |
|
|
|
139 |
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
140 |
show_info("Converting audio...")
|
141 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
142 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
143 |
|
144 |
-
non_silent_segs = silence.split_on_silence(
|
145 |
-
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
|
146 |
-
)
|
147 |
non_silent_wave = AudioSegment.silent(duration=0)
|
148 |
for non_silent_seg in non_silent_segs:
|
149 |
non_silent_wave += non_silent_seg
|
@@ -181,22 +176,27 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
|
181 |
|
182 |
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
|
183 |
|
184 |
-
def infer_process(ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm):
|
185 |
|
|
|
|
|
|
|
186 |
# Split the input text into batches
|
187 |
audio, sr = torchaudio.load(ref_audio)
|
188 |
-
max_chars = int(len(ref_text.encode(
|
189 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
190 |
for i, gen_text in enumerate(gen_text_batches):
|
191 |
-
print(f
|
192 |
-
|
193 |
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
194 |
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
|
195 |
|
196 |
|
197 |
# infer batches
|
198 |
|
199 |
-
|
|
|
|
|
|
|
200 |
audio, sr = ref_audio
|
201 |
if audio.shape[0] > 1:
|
202 |
audio = torch.mean(audio, dim=0, keepdim=True)
|
@@ -212,7 +212,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
|
212 |
generated_waves = []
|
213 |
spectrograms = []
|
214 |
|
215 |
-
if len(ref_text[-1].encode(
|
216 |
ref_text = ref_text + " "
|
217 |
for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
|
218 |
# Prepare the text
|
@@ -221,8 +221,8 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
|
221 |
|
222 |
# Calculate duration
|
223 |
ref_audio_len = audio.shape[-1] // hop_length
|
224 |
-
ref_text_len = len(ref_text.encode(
|
225 |
-
gen_text_len = len(gen_text.encode(
|
226 |
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
227 |
|
228 |
# inference
|
@@ -245,7 +245,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
|
245 |
|
246 |
# wav -> numpy
|
247 |
generated_wave = generated_wave.squeeze().cpu().numpy()
|
248 |
-
|
249 |
generated_waves.append(generated_wave)
|
250 |
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
251 |
|
@@ -280,11 +280,9 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
|
280 |
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
281 |
|
282 |
# Combine
|
283 |
-
new_wave = np.concatenate(
|
284 |
-
prev_wave[:-cross_fade_samples],
|
285 |
-
|
286 |
-
next_wave[cross_fade_samples:]
|
287 |
-
])
|
288 |
|
289 |
final_wave = new_wave
|
290 |
|
@@ -296,6 +294,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
|
296 |
|
297 |
# remove silence from generated wav
|
298 |
|
|
|
299 |
def remove_silence_for_generated_wav(filename):
|
300 |
aseg = AudioSegment.from_file(filename)
|
301 |
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
|
|
19 |
convert_char_to_pinyin,
|
20 |
)
|
21 |
|
22 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
23 |
print(f"Using {device} device")
|
24 |
|
25 |
asr_pipe = pipeline(
|
|
|
50 |
|
51 |
# chunk text into smaller pieces
|
52 |
|
53 |
+
|
54 |
def chunk_text(text, max_chars=135):
|
55 |
"""
|
56 |
Splits the input text into chunks, each with a maximum number of characters.
|
|
|
65 |
chunks = []
|
66 |
current_chunk = ""
|
67 |
# Split the text into sentences based on punctuation followed by whitespace
|
68 |
+
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
69 |
|
70 |
for sentence in sentences:
|
71 |
+
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
72 |
+
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
73 |
else:
|
74 |
if current_chunk:
|
75 |
chunks.append(current_chunk.strip())
|
76 |
+
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
77 |
|
78 |
if current_chunk:
|
79 |
chunks.append(current_chunk.strip())
|
|
|
83 |
|
84 |
# load vocoder
|
85 |
|
86 |
+
|
87 |
def load_vocoder(is_local=False, local_path=""):
|
88 |
if is_local:
|
89 |
print(f"Load vocos from local path {local_path}")
|
|
|
99 |
|
100 |
# load model for inference
|
101 |
|
102 |
+
|
103 |
def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
|
|
|
104 |
if vocab_file == "":
|
105 |
vocab_file = "Emilia_ZH_EN"
|
106 |
tokenizer = "pinyin"
|
107 |
else:
|
108 |
tokenizer = "custom"
|
109 |
|
110 |
+
print("\nvocab : ", vocab_file, tokenizer)
|
111 |
+
print("tokenizer : ", tokenizer)
|
112 |
+
print("model : ", ckpt_path, "\n")
|
113 |
|
114 |
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
115 |
model = CFM(
|
116 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
|
|
|
|
117 |
mel_spec_kwargs=dict(
|
118 |
target_sample_rate=target_sample_rate,
|
119 |
n_mel_channels=n_mel_channels,
|
|
|
125 |
vocab_char_map=vocab_char_map,
|
126 |
).to(device)
|
127 |
|
128 |
+
model = load_checkpoint(model, ckpt_path, device, use_ema=True)
|
129 |
|
130 |
return model
|
131 |
|
132 |
|
133 |
# preprocess reference audio and text
|
134 |
|
135 |
+
|
136 |
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
137 |
show_info("Converting audio...")
|
138 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
139 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
140 |
|
141 |
+
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
|
|
|
|
|
142 |
non_silent_wave = AudioSegment.silent(duration=0)
|
143 |
for non_silent_seg in non_silent_segs:
|
144 |
non_silent_wave += non_silent_seg
|
|
|
176 |
|
177 |
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
|
178 |
|
|
|
179 |
|
180 |
+
def infer_process(
|
181 |
+
ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
|
182 |
+
):
|
183 |
# Split the input text into batches
|
184 |
audio, sr = torchaudio.load(ref_audio)
|
185 |
+
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
186 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
187 |
for i, gen_text in enumerate(gen_text_batches):
|
188 |
+
print(f"gen_text {i}", gen_text)
|
189 |
+
|
190 |
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
191 |
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
|
192 |
|
193 |
|
194 |
# infer batches
|
195 |
|
196 |
+
|
197 |
+
def infer_batch_process(
|
198 |
+
ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
|
199 |
+
):
|
200 |
audio, sr = ref_audio
|
201 |
if audio.shape[0] > 1:
|
202 |
audio = torch.mean(audio, dim=0, keepdim=True)
|
|
|
212 |
generated_waves = []
|
213 |
spectrograms = []
|
214 |
|
215 |
+
if len(ref_text[-1].encode("utf-8")) == 1:
|
216 |
ref_text = ref_text + " "
|
217 |
for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
|
218 |
# Prepare the text
|
|
|
221 |
|
222 |
# Calculate duration
|
223 |
ref_audio_len = audio.shape[-1] // hop_length
|
224 |
+
ref_text_len = len(ref_text.encode("utf-8"))
|
225 |
+
gen_text_len = len(gen_text.encode("utf-8"))
|
226 |
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
227 |
|
228 |
# inference
|
|
|
245 |
|
246 |
# wav -> numpy
|
247 |
generated_wave = generated_wave.squeeze().cpu().numpy()
|
248 |
+
|
249 |
generated_waves.append(generated_wave)
|
250 |
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
251 |
|
|
|
280 |
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
281 |
|
282 |
# Combine
|
283 |
+
new_wave = np.concatenate(
|
284 |
+
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
|
285 |
+
)
|
|
|
|
|
286 |
|
287 |
final_wave = new_wave
|
288 |
|
|
|
294 |
|
295 |
# remove silence from generated wav
|
296 |
|
297 |
+
|
298 |
def remove_silence_for_generated_wav(filename):
|
299 |
aseg = AudioSegment.from_file(filename)
|
300 |
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
ruff.toml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
line-length = 120
|
2 |
+
target-version = "py310"
|
3 |
+
|
4 |
+
[lint]
|
5 |
+
# Only ignore variables with names starting with "_".
|
6 |
+
dummy-variable-rgx = "^_.*$"
|
7 |
+
|
8 |
+
[lint.isort]
|
9 |
+
force-single-line = true
|
10 |
+
lines-after-imports = 2
|
scripts/count_max_epoch.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
print(
|
|
|
4 |
|
5 |
# data
|
6 |
total_hours = 95282
|
|
|
1 |
+
"""ADAPTIVE BATCH SIZE"""
|
2 |
+
|
3 |
+
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
|
4 |
+
print(" -> least padding, gather wavs with accumulated frames in a batch\n")
|
5 |
|
6 |
# data
|
7 |
total_hours = 95282
|
scripts/count_params_gflops.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
-
import sys
|
|
|
|
|
2 |
sys.path.append(os.getcwd())
|
3 |
|
4 |
-
from model import M2_TTS,
|
5 |
|
6 |
import torch
|
7 |
import thop
|
8 |
|
9 |
|
10 |
-
|
11 |
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
|
12 |
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
|
13 |
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
|
@@ -15,11 +17,11 @@ import thop
|
|
15 |
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
|
16 |
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
|
17 |
|
18 |
-
|
19 |
# FLOPs: 622.1 G, Params: 333.2 M
|
20 |
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
21 |
# FLOPs: 363.4 G, Params: 335.8 M
|
22 |
-
transformer =
|
23 |
|
24 |
|
25 |
model = M2_TTS(transformer=transformer)
|
@@ -30,6 +32,8 @@ duration = 20
|
|
30 |
frame_length = int(duration * target_sample_rate / hop_length)
|
31 |
text_length = 150
|
32 |
|
33 |
-
flops, params = thop.profile(
|
|
|
|
|
34 |
print(f"FLOPs: {flops / 1e9} G")
|
35 |
print(f"Params: {params / 1e6} M")
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
sys.path.append(os.getcwd())
|
5 |
|
6 |
+
from model import M2_TTS, DiT
|
7 |
|
8 |
import torch
|
9 |
import thop
|
10 |
|
11 |
|
12 |
+
""" ~155M """
|
13 |
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
|
14 |
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
|
15 |
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
|
|
|
17 |
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
|
18 |
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
|
19 |
|
20 |
+
""" ~335M """
|
21 |
# FLOPs: 622.1 G, Params: 333.2 M
|
22 |
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
23 |
# FLOPs: 363.4 G, Params: 335.8 M
|
24 |
+
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
25 |
|
26 |
|
27 |
model = M2_TTS(transformer=transformer)
|
|
|
32 |
frame_length = int(duration * target_sample_rate / hop_length)
|
33 |
text_length = 150
|
34 |
|
35 |
+
flops, params = thop.profile(
|
36 |
+
model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
|
37 |
+
)
|
38 |
print(f"FLOPs: {flops / 1e9} G")
|
39 |
print(f"Params: {params / 1e6} M")
|
scripts/eval_infer_batch.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
import sys
|
|
|
|
|
2 |
sys.path.append(os.getcwd())
|
3 |
|
4 |
import time
|
@@ -14,9 +16,9 @@ from vocos import Vocos
|
|
14 |
from model import CFM, UNetT, DiT
|
15 |
from model.utils import (
|
16 |
load_checkpoint,
|
17 |
-
get_tokenizer,
|
18 |
-
get_seedtts_testset_metainfo,
|
19 |
-
get_librispeech_test_clean_metainfo,
|
20 |
get_inference_prompt,
|
21 |
)
|
22 |
|
@@ -38,16 +40,16 @@ tokenizer = "pinyin"
|
|
38 |
|
39 |
parser = argparse.ArgumentParser(description="batch inference")
|
40 |
|
41 |
-
parser.add_argument(
|
42 |
-
parser.add_argument(
|
43 |
-
parser.add_argument(
|
44 |
-
parser.add_argument(
|
45 |
|
46 |
-
parser.add_argument(
|
47 |
-
parser.add_argument(
|
48 |
-
parser.add_argument(
|
49 |
|
50 |
-
parser.add_argument(
|
51 |
|
52 |
args = parser.parse_args()
|
53 |
|
@@ -66,26 +68,26 @@ testset = args.testset
|
|
66 |
|
67 |
|
68 |
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
69 |
-
cfg_strength = 2.
|
70 |
-
speed = 1.
|
71 |
use_truth_duration = False
|
72 |
no_ref_audio = False
|
73 |
|
74 |
|
75 |
if exp_name == "F5TTS_Base":
|
76 |
model_cls = DiT
|
77 |
-
model_cfg = dict(dim
|
78 |
|
79 |
elif exp_name == "E2TTS_Base":
|
80 |
model_cls = UNetT
|
81 |
-
model_cfg = dict(dim
|
82 |
|
83 |
|
84 |
if testset == "ls_pc_test_clean":
|
85 |
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
86 |
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
87 |
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
88 |
-
|
89 |
elif testset == "seedtts_test_zh":
|
90 |
metalst = "data/seedtts_testset/zh/meta.lst"
|
91 |
metainfo = get_seedtts_testset_metainfo(metalst)
|
@@ -96,13 +98,16 @@ elif testset == "seedtts_test_en":
|
|
96 |
|
97 |
|
98 |
# path to save genereted wavs
|
99 |
-
if seed is None:
|
100 |
-
|
101 |
-
|
102 |
-
f"{
|
103 |
-
f"
|
104 |
-
f"{'
|
|
|
|
|
105 |
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
|
|
106 |
|
107 |
|
108 |
# -------------------------------------------------#
|
@@ -110,15 +115,15 @@ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
|
|
110 |
use_ema = True
|
111 |
|
112 |
prompts_all = get_inference_prompt(
|
113 |
-
metainfo,
|
114 |
-
speed
|
115 |
-
tokenizer
|
116 |
-
target_sample_rate
|
117 |
-
n_mel_channels
|
118 |
-
hop_length
|
119 |
-
target_rms
|
120 |
-
use_truth_duration
|
121 |
-
infer_batch_size
|
122 |
)
|
123 |
|
124 |
# Vocoder model
|
@@ -137,23 +142,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
|
137 |
|
138 |
# Model
|
139 |
model = CFM(
|
140 |
-
transformer =
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
144 |
),
|
145 |
-
|
146 |
-
|
147 |
-
n_mel_channels = n_mel_channels,
|
148 |
-
hop_length = hop_length,
|
149 |
),
|
150 |
-
|
151 |
-
method = ode_method,
|
152 |
-
),
|
153 |
-
vocab_char_map = vocab_char_map,
|
154 |
).to(device)
|
155 |
|
156 |
-
model = load_checkpoint(model, ckpt_path, device, use_ema
|
157 |
|
158 |
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
159 |
os.makedirs(output_dir)
|
@@ -163,29 +164,28 @@ accelerator.wait_for_everyone()
|
|
163 |
start = time.time()
|
164 |
|
165 |
with accelerator.split_between_processes(prompts_all) as prompts:
|
166 |
-
|
167 |
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
168 |
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
|
169 |
ref_mels = ref_mels.to(device)
|
170 |
-
ref_mel_lens = torch.tensor(ref_mel_lens, dtype
|
171 |
-
total_mel_lens = torch.tensor(total_mel_lens, dtype
|
172 |
-
|
173 |
# Inference
|
174 |
with torch.inference_mode():
|
175 |
generated, _ = model.sample(
|
176 |
-
cond
|
177 |
-
text
|
178 |
-
duration
|
179 |
-
lens
|
180 |
-
steps
|
181 |
-
cfg_strength
|
182 |
-
sway_sampling_coef
|
183 |
-
no_ref_audio
|
184 |
-
seed
|
185 |
)
|
186 |
# Final result
|
187 |
for i, gen in enumerate(generated):
|
188 |
-
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
|
189 |
gen_mel_spec = gen.permute(0, 2, 1)
|
190 |
generated_wave = vocos.decode(gen_mel_spec.cpu())
|
191 |
if ref_rms_list[i] < target_rms:
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
sys.path.append(os.getcwd())
|
5 |
|
6 |
import time
|
|
|
16 |
from model import CFM, UNetT, DiT
|
17 |
from model.utils import (
|
18 |
load_checkpoint,
|
19 |
+
get_tokenizer,
|
20 |
+
get_seedtts_testset_metainfo,
|
21 |
+
get_librispeech_test_clean_metainfo,
|
22 |
get_inference_prompt,
|
23 |
)
|
24 |
|
|
|
40 |
|
41 |
parser = argparse.ArgumentParser(description="batch inference")
|
42 |
|
43 |
+
parser.add_argument("-s", "--seed", default=None, type=int)
|
44 |
+
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
|
45 |
+
parser.add_argument("-n", "--expname", required=True)
|
46 |
+
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
|
47 |
|
48 |
+
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
|
49 |
+
parser.add_argument("-o", "--odemethod", default="euler")
|
50 |
+
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
51 |
|
52 |
+
parser.add_argument("-t", "--testset", required=True)
|
53 |
|
54 |
args = parser.parse_args()
|
55 |
|
|
|
68 |
|
69 |
|
70 |
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
71 |
+
cfg_strength = 2.0
|
72 |
+
speed = 1.0
|
73 |
use_truth_duration = False
|
74 |
no_ref_audio = False
|
75 |
|
76 |
|
77 |
if exp_name == "F5TTS_Base":
|
78 |
model_cls = DiT
|
79 |
+
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
80 |
|
81 |
elif exp_name == "E2TTS_Base":
|
82 |
model_cls = UNetT
|
83 |
+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
84 |
|
85 |
|
86 |
if testset == "ls_pc_test_clean":
|
87 |
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
88 |
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
89 |
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
90 |
+
|
91 |
elif testset == "seedtts_test_zh":
|
92 |
metalst = "data/seedtts_testset/zh/meta.lst"
|
93 |
metainfo = get_seedtts_testset_metainfo(metalst)
|
|
|
98 |
|
99 |
|
100 |
# path to save genereted wavs
|
101 |
+
if seed is None:
|
102 |
+
seed = random.randint(-10000, 10000)
|
103 |
+
output_dir = (
|
104 |
+
f"results/{exp_name}_{ckpt_step}/{testset}/"
|
105 |
+
f"seed{seed}_{ode_method}_nfe{nfe_step}"
|
106 |
+
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
|
107 |
+
f"_cfg{cfg_strength}_speed{speed}"
|
108 |
+
f"{'_gt-dur' if use_truth_duration else ''}"
|
109 |
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
110 |
+
)
|
111 |
|
112 |
|
113 |
# -------------------------------------------------#
|
|
|
115 |
use_ema = True
|
116 |
|
117 |
prompts_all = get_inference_prompt(
|
118 |
+
metainfo,
|
119 |
+
speed=speed,
|
120 |
+
tokenizer=tokenizer,
|
121 |
+
target_sample_rate=target_sample_rate,
|
122 |
+
n_mel_channels=n_mel_channels,
|
123 |
+
hop_length=hop_length,
|
124 |
+
target_rms=target_rms,
|
125 |
+
use_truth_duration=use_truth_duration,
|
126 |
+
infer_batch_size=infer_batch_size,
|
127 |
)
|
128 |
|
129 |
# Vocoder model
|
|
|
142 |
|
143 |
# Model
|
144 |
model = CFM(
|
145 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
146 |
+
mel_spec_kwargs=dict(
|
147 |
+
target_sample_rate=target_sample_rate,
|
148 |
+
n_mel_channels=n_mel_channels,
|
149 |
+
hop_length=hop_length,
|
150 |
),
|
151 |
+
odeint_kwargs=dict(
|
152 |
+
method=ode_method,
|
|
|
|
|
153 |
),
|
154 |
+
vocab_char_map=vocab_char_map,
|
|
|
|
|
|
|
155 |
).to(device)
|
156 |
|
157 |
+
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
|
158 |
|
159 |
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
160 |
os.makedirs(output_dir)
|
|
|
164 |
start = time.time()
|
165 |
|
166 |
with accelerator.split_between_processes(prompts_all) as prompts:
|
|
|
167 |
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
168 |
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
|
169 |
ref_mels = ref_mels.to(device)
|
170 |
+
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
171 |
+
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
172 |
+
|
173 |
# Inference
|
174 |
with torch.inference_mode():
|
175 |
generated, _ = model.sample(
|
176 |
+
cond=ref_mels,
|
177 |
+
text=final_text_list,
|
178 |
+
duration=total_mel_lens,
|
179 |
+
lens=ref_mel_lens,
|
180 |
+
steps=nfe_step,
|
181 |
+
cfg_strength=cfg_strength,
|
182 |
+
sway_sampling_coef=sway_sampling_coef,
|
183 |
+
no_ref_audio=no_ref_audio,
|
184 |
+
seed=seed,
|
185 |
)
|
186 |
# Final result
|
187 |
for i, gen in enumerate(generated):
|
188 |
+
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
189 |
gen_mel_spec = gen.permute(0, 2, 1)
|
190 |
generated_wave = vocos.decode(gen_mel_spec.cpu())
|
191 |
if ref_rms_list[i] < target_rms:
|
scripts/eval_librispeech_test_clean.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
2 |
|
3 |
-
import sys
|
|
|
|
|
4 |
sys.path.append(os.getcwd())
|
5 |
|
6 |
import multiprocessing as mp
|
@@ -19,7 +21,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
|
19 |
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
20 |
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
21 |
|
22 |
-
gpus = [0,1,2,3,4,5,6,7]
|
23 |
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
24 |
|
25 |
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
@@ -46,7 +48,7 @@ if eval_task == "wer":
|
|
46 |
for wers_ in results:
|
47 |
wers.extend(wers_)
|
48 |
|
49 |
-
wer = round(np.mean(wers)*100, 3)
|
50 |
print(f"\nTotal {len(wers)} samples")
|
51 |
print(f"WER : {wer}%")
|
52 |
|
@@ -62,6 +64,6 @@ if eval_task == "sim":
|
|
62 |
for sim_ in results:
|
63 |
sim_list.extend(sim_)
|
64 |
|
65 |
-
sim = round(sum(sim_list)/len(sim_list), 3)
|
66 |
print(f"\nTotal {len(sim_list)} samples")
|
67 |
print(f"SIM : {sim}")
|
|
|
1 |
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
2 |
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
|
6 |
sys.path.append(os.getcwd())
|
7 |
|
8 |
import multiprocessing as mp
|
|
|
21 |
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
22 |
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
23 |
|
24 |
+
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
|
25 |
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
26 |
|
27 |
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
|
|
48 |
for wers_ in results:
|
49 |
wers.extend(wers_)
|
50 |
|
51 |
+
wer = round(np.mean(wers) * 100, 3)
|
52 |
print(f"\nTotal {len(wers)} samples")
|
53 |
print(f"WER : {wer}%")
|
54 |
|
|
|
64 |
for sim_ in results:
|
65 |
sim_list.extend(sim_)
|
66 |
|
67 |
+
sim = round(sum(sim_list) / len(sim_list), 3)
|
68 |
print(f"\nTotal {len(sim_list)} samples")
|
69 |
print(f"SIM : {sim}")
|
scripts/eval_seedtts_testset.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
# Evaluate with Seed-TTS testset
|
2 |
|
3 |
-
import sys
|
|
|
|
|
4 |
sys.path.append(os.getcwd())
|
5 |
|
6 |
import multiprocessing as mp
|
@@ -14,21 +16,21 @@ from model.utils import (
|
|
14 |
|
15 |
|
16 |
eval_task = "wer" # sim | wer
|
17 |
-
lang = "zh"
|
18 |
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
|
19 |
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
|
20 |
-
gen_wav_dir =
|
21 |
|
22 |
|
23 |
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
24 |
-
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
25 |
-
gpus = [0,1,2,3,4,5,6,7]
|
26 |
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
27 |
|
28 |
local = False
|
29 |
if local: # use local custom checkpoint dir
|
30 |
if lang == "zh":
|
31 |
-
asr_ckpt_dir = "../checkpoints/funasr"
|
32 |
elif lang == "en":
|
33 |
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
34 |
else:
|
@@ -48,7 +50,7 @@ if eval_task == "wer":
|
|
48 |
for wers_ in results:
|
49 |
wers.extend(wers_)
|
50 |
|
51 |
-
wer = round(np.mean(wers)*100, 3)
|
52 |
print(f"\nTotal {len(wers)} samples")
|
53 |
print(f"WER : {wer}%")
|
54 |
|
@@ -64,6 +66,6 @@ if eval_task == "sim":
|
|
64 |
for sim_ in results:
|
65 |
sim_list.extend(sim_)
|
66 |
|
67 |
-
sim = round(sum(sim_list)/len(sim_list), 3)
|
68 |
print(f"\nTotal {len(sim_list)} samples")
|
69 |
print(f"SIM : {sim}")
|
|
|
1 |
# Evaluate with Seed-TTS testset
|
2 |
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
|
6 |
sys.path.append(os.getcwd())
|
7 |
|
8 |
import multiprocessing as mp
|
|
|
16 |
|
17 |
|
18 |
eval_task = "wer" # sim | wer
|
19 |
+
lang = "zh" # zh | en
|
20 |
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
|
21 |
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
|
22 |
+
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
23 |
|
24 |
|
25 |
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
26 |
+
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
27 |
+
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
|
28 |
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
29 |
|
30 |
local = False
|
31 |
if local: # use local custom checkpoint dir
|
32 |
if lang == "zh":
|
33 |
+
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
|
34 |
elif lang == "en":
|
35 |
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
36 |
else:
|
|
|
50 |
for wers_ in results:
|
51 |
wers.extend(wers_)
|
52 |
|
53 |
+
wer = round(np.mean(wers) * 100, 3)
|
54 |
print(f"\nTotal {len(wers)} samples")
|
55 |
print(f"WER : {wer}%")
|
56 |
|
|
|
66 |
for sim_ in results:
|
67 |
sim_list.extend(sim_)
|
68 |
|
69 |
+
sim = round(sum(sim_list) / len(sim_list), 3)
|
70 |
print(f"\nTotal {len(sim_list)} samples")
|
71 |
print(f"SIM : {sim}")
|
scripts/prepare_csv_wavs.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
import sys
|
|
|
|
|
2 |
sys.path.append(os.getcwd())
|
3 |
|
4 |
from pathlib import Path
|
@@ -17,10 +19,11 @@ from model.utils import (
|
|
17 |
|
18 |
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
19 |
|
|
|
20 |
def is_csv_wavs_format(input_dataset_dir):
|
21 |
fpath = Path(input_dataset_dir)
|
22 |
metadata = fpath / "metadata.csv"
|
23 |
-
wavs = fpath /
|
24 |
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
|
25 |
|
26 |
|
@@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir):
|
|
46 |
|
47 |
return sub_result, durations, vocab_set
|
48 |
|
|
|
49 |
def get_audio_duration(audio_path):
|
50 |
audio, sample_rate = torchaudio.load(audio_path)
|
51 |
num_channels = audio.shape[0]
|
52 |
return audio.shape[1] / (sample_rate * num_channels)
|
53 |
|
|
|
54 |
def read_audio_text_pairs(csv_file_path):
|
55 |
audio_text_pairs = []
|
56 |
|
57 |
parent = Path(csv_file_path).parent
|
58 |
-
with open(csv_file_path, mode=
|
59 |
-
reader = csv.reader(csvfile, delimiter=
|
60 |
next(reader) # Skip the header row
|
61 |
for row in reader:
|
62 |
if len(row) >= 2:
|
63 |
audio_file = row[0].strip() # First column: audio file path
|
64 |
-
text = row[1].strip()
|
65 |
audio_file_path = parent / audio_file
|
66 |
audio_text_pairs.append((audio_file_path.as_posix(), text))
|
67 |
|
@@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
|
78 |
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
79 |
raw_arrow_path = out_dir / "raw.arrow"
|
80 |
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
|
81 |
-
for line in tqdm(result, desc=
|
82 |
writer.write(line)
|
83 |
|
84 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
85 |
dur_json_path = out_dir / "duration.json"
|
86 |
-
with open(dur_json_path.as_posix(),
|
87 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
88 |
|
89 |
# vocab map, i.e. tokenizer
|
@@ -120,13 +125,14 @@ def cli():
|
|
120 |
# finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
|
121 |
# pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
|
122 |
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
|
123 |
-
parser.add_argument(
|
124 |
-
parser.add_argument(
|
125 |
-
parser.add_argument(
|
126 |
|
127 |
args = parser.parse_args()
|
128 |
|
129 |
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
|
130 |
|
|
|
131 |
if __name__ == "__main__":
|
132 |
cli()
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
sys.path.append(os.getcwd())
|
5 |
|
6 |
from pathlib import Path
|
|
|
19 |
|
20 |
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
21 |
|
22 |
+
|
23 |
def is_csv_wavs_format(input_dataset_dir):
|
24 |
fpath = Path(input_dataset_dir)
|
25 |
metadata = fpath / "metadata.csv"
|
26 |
+
wavs = fpath / "wavs"
|
27 |
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
|
28 |
|
29 |
|
|
|
49 |
|
50 |
return sub_result, durations, vocab_set
|
51 |
|
52 |
+
|
53 |
def get_audio_duration(audio_path):
|
54 |
audio, sample_rate = torchaudio.load(audio_path)
|
55 |
num_channels = audio.shape[0]
|
56 |
return audio.shape[1] / (sample_rate * num_channels)
|
57 |
|
58 |
+
|
59 |
def read_audio_text_pairs(csv_file_path):
|
60 |
audio_text_pairs = []
|
61 |
|
62 |
parent = Path(csv_file_path).parent
|
63 |
+
with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
|
64 |
+
reader = csv.reader(csvfile, delimiter="|")
|
65 |
next(reader) # Skip the header row
|
66 |
for row in reader:
|
67 |
if len(row) >= 2:
|
68 |
audio_file = row[0].strip() # First column: audio file path
|
69 |
+
text = row[1].strip() # Second column: text
|
70 |
audio_file_path = parent / audio_file
|
71 |
audio_text_pairs.append((audio_file_path.as_posix(), text))
|
72 |
|
|
|
83 |
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
84 |
raw_arrow_path = out_dir / "raw.arrow"
|
85 |
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
|
86 |
+
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
87 |
writer.write(line)
|
88 |
|
89 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
90 |
dur_json_path = out_dir / "duration.json"
|
91 |
+
with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
|
92 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
93 |
|
94 |
# vocab map, i.e. tokenizer
|
|
|
125 |
# finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
|
126 |
# pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
|
127 |
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
|
128 |
+
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
|
129 |
+
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
|
130 |
+
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
|
131 |
|
132 |
args = parser.parse_args()
|
133 |
|
134 |
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
|
135 |
|
136 |
+
|
137 |
if __name__ == "__main__":
|
138 |
cli()
|
scripts/prepare_emilia.py
CHANGED
@@ -4,7 +4,9 @@
|
|
4 |
# generate audio text map for Emilia ZH & EN
|
5 |
# evaluate for vocab size
|
6 |
|
7 |
-
import sys
|
|
|
|
|
8 |
sys.path.append(os.getcwd())
|
9 |
|
10 |
from pathlib import Path
|
@@ -12,7 +14,6 @@ import json
|
|
12 |
from tqdm import tqdm
|
13 |
from concurrent.futures import ProcessPoolExecutor
|
14 |
|
15 |
-
from datasets import Dataset
|
16 |
from datasets.arrow_writer import ArrowWriter
|
17 |
|
18 |
from model.utils import (
|
@@ -21,13 +22,89 @@ from model.utils import (
|
|
21 |
)
|
22 |
|
23 |
|
24 |
-
out_zh = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
zh_filters = ["い", "て"]
|
26 |
# seems synthesized audios, or heavily code-switched
|
27 |
out_en = {
|
28 |
-
"EN_B00013_S00913",
|
29 |
-
|
30 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
}
|
32 |
en_filters = ["ا", "い", "て"]
|
33 |
|
@@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir):
|
|
43 |
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
|
44 |
obj = json.loads(line)
|
45 |
text = obj["text"]
|
46 |
-
if obj[
|
47 |
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
|
48 |
bad_case_zh += 1
|
49 |
continue
|
50 |
else:
|
51 |
-
text = text.translate(
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
bad_case_en += 1
|
55 |
continue
|
56 |
if tokenizer == "pinyin":
|
57 |
-
text = convert_char_to_pinyin([text], polyphone
|
58 |
duration = obj["duration"]
|
59 |
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
|
60 |
durations.append(duration)
|
@@ -96,11 +179,11 @@ def main():
|
|
96 |
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
|
97 |
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
98 |
with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
|
99 |
-
for line in tqdm(result, desc=
|
100 |
writer.write(line)
|
101 |
|
102 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
103 |
-
with open(f"data/{dataset_name}/duration.json",
|
104 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
105 |
|
106 |
# vocab map, i.e. tokenizer
|
@@ -114,12 +197,13 @@ def main():
|
|
114 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
115 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
116 |
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
117 |
-
if "ZH" in langs:
|
118 |
-
|
|
|
|
|
119 |
|
120 |
|
121 |
if __name__ == "__main__":
|
122 |
-
|
123 |
max_workers = 32
|
124 |
|
125 |
tokenizer = "pinyin" # "pinyin" | "char"
|
|
|
4 |
# generate audio text map for Emilia ZH & EN
|
5 |
# evaluate for vocab size
|
6 |
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
|
10 |
sys.path.append(os.getcwd())
|
11 |
|
12 |
from pathlib import Path
|
|
|
14 |
from tqdm import tqdm
|
15 |
from concurrent.futures import ProcessPoolExecutor
|
16 |
|
|
|
17 |
from datasets.arrow_writer import ArrowWriter
|
18 |
|
19 |
from model.utils import (
|
|
|
22 |
)
|
23 |
|
24 |
|
25 |
+
out_zh = {
|
26 |
+
"ZH_B00041_S06226",
|
27 |
+
"ZH_B00042_S09204",
|
28 |
+
"ZH_B00065_S09430",
|
29 |
+
"ZH_B00065_S09431",
|
30 |
+
"ZH_B00066_S09327",
|
31 |
+
"ZH_B00066_S09328",
|
32 |
+
}
|
33 |
zh_filters = ["い", "て"]
|
34 |
# seems synthesized audios, or heavily code-switched
|
35 |
out_en = {
|
36 |
+
"EN_B00013_S00913",
|
37 |
+
"EN_B00042_S00120",
|
38 |
+
"EN_B00055_S04111",
|
39 |
+
"EN_B00061_S00693",
|
40 |
+
"EN_B00061_S01494",
|
41 |
+
"EN_B00061_S03375",
|
42 |
+
"EN_B00059_S00092",
|
43 |
+
"EN_B00111_S04300",
|
44 |
+
"EN_B00100_S03759",
|
45 |
+
"EN_B00087_S03811",
|
46 |
+
"EN_B00059_S00950",
|
47 |
+
"EN_B00089_S00946",
|
48 |
+
"EN_B00078_S05127",
|
49 |
+
"EN_B00070_S04089",
|
50 |
+
"EN_B00074_S09659",
|
51 |
+
"EN_B00061_S06983",
|
52 |
+
"EN_B00061_S07060",
|
53 |
+
"EN_B00059_S08397",
|
54 |
+
"EN_B00082_S06192",
|
55 |
+
"EN_B00091_S01238",
|
56 |
+
"EN_B00089_S07349",
|
57 |
+
"EN_B00070_S04343",
|
58 |
+
"EN_B00061_S02400",
|
59 |
+
"EN_B00076_S01262",
|
60 |
+
"EN_B00068_S06467",
|
61 |
+
"EN_B00076_S02943",
|
62 |
+
"EN_B00064_S05954",
|
63 |
+
"EN_B00061_S05386",
|
64 |
+
"EN_B00066_S06544",
|
65 |
+
"EN_B00076_S06944",
|
66 |
+
"EN_B00072_S08620",
|
67 |
+
"EN_B00076_S07135",
|
68 |
+
"EN_B00076_S09127",
|
69 |
+
"EN_B00065_S00497",
|
70 |
+
"EN_B00059_S06227",
|
71 |
+
"EN_B00063_S02859",
|
72 |
+
"EN_B00075_S01547",
|
73 |
+
"EN_B00061_S08286",
|
74 |
+
"EN_B00079_S02901",
|
75 |
+
"EN_B00092_S03643",
|
76 |
+
"EN_B00096_S08653",
|
77 |
+
"EN_B00063_S04297",
|
78 |
+
"EN_B00063_S04614",
|
79 |
+
"EN_B00079_S04698",
|
80 |
+
"EN_B00104_S01666",
|
81 |
+
"EN_B00061_S09504",
|
82 |
+
"EN_B00061_S09694",
|
83 |
+
"EN_B00065_S05444",
|
84 |
+
"EN_B00063_S06860",
|
85 |
+
"EN_B00065_S05725",
|
86 |
+
"EN_B00069_S07628",
|
87 |
+
"EN_B00083_S03875",
|
88 |
+
"EN_B00071_S07665",
|
89 |
+
"EN_B00071_S07665",
|
90 |
+
"EN_B00062_S04187",
|
91 |
+
"EN_B00065_S09873",
|
92 |
+
"EN_B00065_S09922",
|
93 |
+
"EN_B00084_S02463",
|
94 |
+
"EN_B00067_S05066",
|
95 |
+
"EN_B00106_S08060",
|
96 |
+
"EN_B00073_S06399",
|
97 |
+
"EN_B00073_S09236",
|
98 |
+
"EN_B00087_S00432",
|
99 |
+
"EN_B00085_S05618",
|
100 |
+
"EN_B00064_S01262",
|
101 |
+
"EN_B00072_S01739",
|
102 |
+
"EN_B00059_S03913",
|
103 |
+
"EN_B00069_S04036",
|
104 |
+
"EN_B00067_S05623",
|
105 |
+
"EN_B00060_S05389",
|
106 |
+
"EN_B00060_S07290",
|
107 |
+
"EN_B00062_S08995",
|
108 |
}
|
109 |
en_filters = ["ا", "い", "て"]
|
110 |
|
|
|
120 |
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
|
121 |
obj = json.loads(line)
|
122 |
text = obj["text"]
|
123 |
+
if obj["language"] == "zh":
|
124 |
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
|
125 |
bad_case_zh += 1
|
126 |
continue
|
127 |
else:
|
128 |
+
text = text.translate(
|
129 |
+
str.maketrans({",": ",", "!": "!", "?": "?"})
|
130 |
+
) # not "。" cuz much code-switched
|
131 |
+
if obj["language"] == "en":
|
132 |
+
if (
|
133 |
+
obj["wav"].split("/")[1] in out_en
|
134 |
+
or any(f in text for f in en_filters)
|
135 |
+
or repetition_found(text, length=4)
|
136 |
+
):
|
137 |
bad_case_en += 1
|
138 |
continue
|
139 |
if tokenizer == "pinyin":
|
140 |
+
text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
|
141 |
duration = obj["duration"]
|
142 |
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
|
143 |
durations.append(duration)
|
|
|
179 |
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
|
180 |
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
181 |
with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
|
182 |
+
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
183 |
writer.write(line)
|
184 |
|
185 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
186 |
+
with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
|
187 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
188 |
|
189 |
# vocab map, i.e. tokenizer
|
|
|
197 |
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
198 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
199 |
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
200 |
+
if "ZH" in langs:
|
201 |
+
print(f"Bad zh transcription case: {total_bad_case_zh}")
|
202 |
+
if "EN" in langs:
|
203 |
+
print(f"Bad en transcription case: {total_bad_case_en}\n")
|
204 |
|
205 |
|
206 |
if __name__ == "__main__":
|
|
|
207 |
max_workers = 32
|
208 |
|
209 |
tokenizer = "pinyin" # "pinyin" | "char"
|
scripts/prepare_wenetspeech4tts.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
# generate audio text map for WenetSpeech4TTS
|
2 |
# evaluate for vocab size
|
3 |
|
4 |
-
import sys
|
|
|
|
|
5 |
sys.path.append(os.getcwd())
|
6 |
|
7 |
import json
|
@@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
|
|
23 |
|
24 |
audio_paths, texts, durations = [], [], []
|
25 |
for text_file in tqdm(text_files):
|
26 |
-
with open(os.path.join(text_dir, text_file),
|
27 |
first_line = file.readline().split("\t")
|
28 |
audio_nm = first_line[0]
|
29 |
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
|
@@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
|
|
32 |
audio_paths.append(audio_path)
|
33 |
|
34 |
if tokenizer == "pinyin":
|
35 |
-
texts.extend(convert_char_to_pinyin([text], polyphone
|
36 |
elif tokenizer == "char":
|
37 |
texts.append(text)
|
38 |
|
@@ -46,7 +48,7 @@ def main():
|
|
46 |
assert tokenizer in ["pinyin", "char"]
|
47 |
|
48 |
audio_path_list, text_list, duration_list = [], [], []
|
49 |
-
|
50 |
executor = ProcessPoolExecutor(max_workers=max_workers)
|
51 |
futures = []
|
52 |
for dataset_path in dataset_paths:
|
@@ -68,8 +70,10 @@ def main():
|
|
68 |
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
|
69 |
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
|
70 |
|
71 |
-
with open(f"data/{dataset_name}_{tokenizer}/duration.json",
|
72 |
-
json.dump(
|
|
|
|
|
73 |
|
74 |
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
|
75 |
text_vocab_set = set()
|
@@ -85,22 +89,21 @@ def main():
|
|
85 |
f.write(vocab + "\n")
|
86 |
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
|
87 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
|
88 |
-
|
89 |
|
90 |
-
if __name__ == "__main__":
|
91 |
|
|
|
92 |
max_workers = 32
|
93 |
|
94 |
tokenizer = "pinyin" # "pinyin" | "char"
|
95 |
polyphone = True
|
96 |
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
|
97 |
|
98 |
-
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
|
99 |
dataset_paths = [
|
100 |
"<SOME_PATH>/WenetSpeech4TTS/Basic",
|
101 |
"<SOME_PATH>/WenetSpeech4TTS/Standard",
|
102 |
"<SOME_PATH>/WenetSpeech4TTS/Premium",
|
103 |
-
|
104 |
print(f"\nChoose Dataset: {dataset_name}\n")
|
105 |
|
106 |
main()
|
@@ -109,8 +112,8 @@ if __name__ == "__main__":
|
|
109 |
# WenetSpeech4TTS Basic Standard Premium
|
110 |
# samples count 3932473 1941220 407494
|
111 |
# pinyin vocab size 1349 1348 1344 (no polyphone)
|
112 |
-
# - - 1459 (polyphone)
|
113 |
# char vocab size 5264 5219 5042
|
114 |
-
|
115 |
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
116 |
# please be careful if using pretrained model, make sure the vocab.txt is same
|
|
|
1 |
# generate audio text map for WenetSpeech4TTS
|
2 |
# evaluate for vocab size
|
3 |
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
|
7 |
sys.path.append(os.getcwd())
|
8 |
|
9 |
import json
|
|
|
25 |
|
26 |
audio_paths, texts, durations = [], [], []
|
27 |
for text_file in tqdm(text_files):
|
28 |
+
with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
|
29 |
first_line = file.readline().split("\t")
|
30 |
audio_nm = first_line[0]
|
31 |
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
|
|
|
34 |
audio_paths.append(audio_path)
|
35 |
|
36 |
if tokenizer == "pinyin":
|
37 |
+
texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
|
38 |
elif tokenizer == "char":
|
39 |
texts.append(text)
|
40 |
|
|
|
48 |
assert tokenizer in ["pinyin", "char"]
|
49 |
|
50 |
audio_path_list, text_list, duration_list = [], [], []
|
51 |
+
|
52 |
executor = ProcessPoolExecutor(max_workers=max_workers)
|
53 |
futures = []
|
54 |
for dataset_path in dataset_paths:
|
|
|
70 |
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
|
71 |
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
|
72 |
|
73 |
+
with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
|
74 |
+
json.dump(
|
75 |
+
{"duration": duration_list}, f, ensure_ascii=False
|
76 |
+
) # dup a json separately saving duration in case for DynamicBatchSampler ease
|
77 |
|
78 |
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
|
79 |
text_vocab_set = set()
|
|
|
89 |
f.write(vocab + "\n")
|
90 |
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
|
91 |
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
|
|
|
92 |
|
|
|
93 |
|
94 |
+
if __name__ == "__main__":
|
95 |
max_workers = 32
|
96 |
|
97 |
tokenizer = "pinyin" # "pinyin" | "char"
|
98 |
polyphone = True
|
99 |
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
|
100 |
|
101 |
+
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
|
102 |
dataset_paths = [
|
103 |
"<SOME_PATH>/WenetSpeech4TTS/Basic",
|
104 |
"<SOME_PATH>/WenetSpeech4TTS/Standard",
|
105 |
"<SOME_PATH>/WenetSpeech4TTS/Premium",
|
106 |
+
][-dataset_choice:]
|
107 |
print(f"\nChoose Dataset: {dataset_name}\n")
|
108 |
|
109 |
main()
|
|
|
112 |
# WenetSpeech4TTS Basic Standard Premium
|
113 |
# samples count 3932473 1941220 407494
|
114 |
# pinyin vocab size 1349 1348 1344 (no polyphone)
|
115 |
+
# - - 1459 (polyphone)
|
116 |
# char vocab size 5264 5219 5042
|
117 |
+
|
118 |
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
119 |
# please be careful if using pretrained model, make sure the vocab.txt is same
|
speech_edit.py
CHANGED
@@ -5,11 +5,11 @@ import torch.nn.functional as F
|
|
5 |
import torchaudio
|
6 |
from vocos import Vocos
|
7 |
|
8 |
-
from model import CFM, UNetT, DiT
|
9 |
from model.utils import (
|
10 |
load_checkpoint,
|
11 |
-
get_tokenizer,
|
12 |
-
convert_char_to_pinyin,
|
13 |
save_spectrogram,
|
14 |
)
|
15 |
|
@@ -35,18 +35,18 @@ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
|
35 |
ckpt_step = 1200000
|
36 |
|
37 |
nfe_step = 32 # 16, 32
|
38 |
-
cfg_strength = 2.
|
39 |
-
ode_method =
|
40 |
-
sway_sampling_coef = -1.
|
41 |
-
speed = 1.
|
42 |
|
43 |
if exp_name == "F5TTS_Base":
|
44 |
model_cls = DiT
|
45 |
-
model_cfg = dict(dim
|
46 |
|
47 |
elif exp_name == "E2TTS_Base":
|
48 |
model_cls = UNetT
|
49 |
-
model_cfg = dict(dim
|
50 |
|
51 |
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
52 |
output_dir = "tests"
|
@@ -62,8 +62,14 @@ output_dir = "tests"
|
|
62 |
audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
|
63 |
origin_text = "Some call me nature, others call me mother nature."
|
64 |
target_text = "Some call me optimist, others call me realist."
|
65 |
-
parts_to_edit = [
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
# audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
|
69 |
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
|
@@ -86,7 +92,7 @@ if local:
|
|
86 |
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
87 |
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
|
88 |
vocos.load_state_dict(state_dict)
|
89 |
-
|
90 |
vocos.eval()
|
91 |
else:
|
92 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
@@ -96,23 +102,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
|
96 |
|
97 |
# Model
|
98 |
model = CFM(
|
99 |
-
transformer =
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
mel_spec_kwargs = dict(
|
105 |
-
target_sample_rate = target_sample_rate,
|
106 |
-
n_mel_channels = n_mel_channels,
|
107 |
-
hop_length = hop_length,
|
108 |
),
|
109 |
-
odeint_kwargs
|
110 |
-
method
|
111 |
),
|
112 |
-
vocab_char_map
|
113 |
).to(device)
|
114 |
|
115 |
-
model = load_checkpoint(model, ckpt_path, device, use_ema
|
116 |
|
117 |
# Audio
|
118 |
audio, sr = torchaudio.load(audio_to_edit)
|
@@ -132,14 +134,18 @@ for part in parts_to_edit:
|
|
132 |
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
|
133 |
part_dur = part_dur * target_sample_rate
|
134 |
start = start * target_sample_rate
|
135 |
-
audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim
|
136 |
-
edit_mask = torch.cat(
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
offset = end * target_sample_rate
|
141 |
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
|
142 |
-
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value
|
143 |
audio = audio.to(device)
|
144 |
edit_mask = edit_mask.to(device)
|
145 |
|
@@ -159,14 +165,14 @@ duration = audio.shape[-1] // hop_length
|
|
159 |
# Inference
|
160 |
with torch.inference_mode():
|
161 |
generated, trajectory = model.sample(
|
162 |
-
cond
|
163 |
-
text
|
164 |
-
duration
|
165 |
-
steps
|
166 |
-
cfg_strength
|
167 |
-
sway_sampling_coef
|
168 |
-
seed
|
169 |
-
edit_mask
|
170 |
)
|
171 |
print(f"Generated mel: {generated.shape}")
|
172 |
|
|
|
5 |
import torchaudio
|
6 |
from vocos import Vocos
|
7 |
|
8 |
+
from model import CFM, UNetT, DiT
|
9 |
from model.utils import (
|
10 |
load_checkpoint,
|
11 |
+
get_tokenizer,
|
12 |
+
convert_char_to_pinyin,
|
13 |
save_spectrogram,
|
14 |
)
|
15 |
|
|
|
35 |
ckpt_step = 1200000
|
36 |
|
37 |
nfe_step = 32 # 16, 32
|
38 |
+
cfg_strength = 2.0
|
39 |
+
ode_method = "euler" # euler | midpoint
|
40 |
+
sway_sampling_coef = -1.0
|
41 |
+
speed = 1.0
|
42 |
|
43 |
if exp_name == "F5TTS_Base":
|
44 |
model_cls = DiT
|
45 |
+
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
46 |
|
47 |
elif exp_name == "E2TTS_Base":
|
48 |
model_cls = UNetT
|
49 |
+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
50 |
|
51 |
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
52 |
output_dir = "tests"
|
|
|
62 |
audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
|
63 |
origin_text = "Some call me nature, others call me mother nature."
|
64 |
target_text = "Some call me optimist, others call me realist."
|
65 |
+
parts_to_edit = [
|
66 |
+
[1.42, 2.44],
|
67 |
+
[4.04, 4.9],
|
68 |
+
] # stard_ends of "nature" & "mother nature", in seconds
|
69 |
+
fix_duration = [
|
70 |
+
1.2,
|
71 |
+
1,
|
72 |
+
] # fix duration for "optimist" & "realist", in seconds
|
73 |
|
74 |
# audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
|
75 |
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
|
|
|
92 |
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
93 |
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
|
94 |
vocos.load_state_dict(state_dict)
|
95 |
+
|
96 |
vocos.eval()
|
97 |
else:
|
98 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
|
|
102 |
|
103 |
# Model
|
104 |
model = CFM(
|
105 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
106 |
+
mel_spec_kwargs=dict(
|
107 |
+
target_sample_rate=target_sample_rate,
|
108 |
+
n_mel_channels=n_mel_channels,
|
109 |
+
hop_length=hop_length,
|
|
|
|
|
|
|
|
|
110 |
),
|
111 |
+
odeint_kwargs=dict(
|
112 |
+
method=ode_method,
|
113 |
),
|
114 |
+
vocab_char_map=vocab_char_map,
|
115 |
).to(device)
|
116 |
|
117 |
+
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
|
118 |
|
119 |
# Audio
|
120 |
audio, sr = torchaudio.load(audio_to_edit)
|
|
|
134 |
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
|
135 |
part_dur = part_dur * target_sample_rate
|
136 |
start = start * target_sample_rate
|
137 |
+
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
|
138 |
+
edit_mask = torch.cat(
|
139 |
+
(
|
140 |
+
edit_mask,
|
141 |
+
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
|
142 |
+
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
|
143 |
+
),
|
144 |
+
dim=-1,
|
145 |
+
)
|
146 |
offset = end * target_sample_rate
|
147 |
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
|
148 |
+
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
|
149 |
audio = audio.to(device)
|
150 |
edit_mask = edit_mask.to(device)
|
151 |
|
|
|
165 |
# Inference
|
166 |
with torch.inference_mode():
|
167 |
generated, trajectory = model.sample(
|
168 |
+
cond=audio,
|
169 |
+
text=final_text_list,
|
170 |
+
duration=duration,
|
171 |
+
steps=nfe_step,
|
172 |
+
cfg_strength=cfg_strength,
|
173 |
+
sway_sampling_coef=sway_sampling_coef,
|
174 |
+
seed=seed,
|
175 |
+
edit_mask=edit_mask,
|
176 |
)
|
177 |
print(f"Generated mel: {generated.shape}")
|
178 |
|
train.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from model import CFM, UNetT, DiT,
|
2 |
from model.utils import get_tokenizer
|
3 |
from model.dataset import load_dataset
|
4 |
|
@@ -9,8 +9,8 @@ target_sample_rate = 24000
|
|
9 |
n_mel_channels = 100
|
10 |
hop_length = 256
|
11 |
|
12 |
-
tokenizer = "pinyin"
|
13 |
-
tokenizer_path = None
|
14 |
dataset_name = "Emilia_ZH_EN"
|
15 |
|
16 |
# -------------------------- Training Settings -------------------------- #
|
@@ -23,7 +23,7 @@ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
|
|
23 |
batch_size_type = "frame" # "frame" or "sample"
|
24 |
max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
|
25 |
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
|
26 |
-
max_grad_norm = 1.
|
27 |
|
28 |
epochs = 11 # use linear decay, thus epochs control the slope
|
29 |
num_warmup_updates = 20000 # warmup steps
|
@@ -34,15 +34,16 @@ last_per_steps = 5000 # save last checkpoint per steps
|
|
34 |
if exp_name == "F5TTS_Base":
|
35 |
wandb_resume_id = None
|
36 |
model_cls = DiT
|
37 |
-
model_cfg = dict(dim
|
38 |
elif exp_name == "E2TTS_Base":
|
39 |
wandb_resume_id = None
|
40 |
model_cls = UNetT
|
41 |
-
model_cfg = dict(dim
|
42 |
|
43 |
|
44 |
# ----------------------------------------------------------------------- #
|
45 |
|
|
|
46 |
def main():
|
47 |
if tokenizer == "custom":
|
48 |
tokenizer_path = tokenizer_path
|
@@ -51,44 +52,41 @@ def main():
|
|
51 |
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
52 |
|
53 |
mel_spec_kwargs = dict(
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
model = CFM(
|
60 |
-
transformer =
|
61 |
-
|
62 |
-
|
63 |
-
mel_dim = n_mel_channels
|
64 |
-
),
|
65 |
-
mel_spec_kwargs = mel_spec_kwargs,
|
66 |
-
vocab_char_map = vocab_char_map,
|
67 |
)
|
68 |
|
69 |
trainer = Trainer(
|
70 |
model,
|
71 |
-
epochs,
|
72 |
learning_rate,
|
73 |
-
num_warmup_updates
|
74 |
-
save_per_updates
|
75 |
-
checkpoint_path
|
76 |
-
batch_size
|
77 |
-
batch_size_type
|
78 |
-
max_samples
|
79 |
-
grad_accumulation_steps
|
80 |
-
max_grad_norm
|
81 |
-
wandb_project
|
82 |
-
wandb_run_name
|
83 |
-
wandb_resume_id
|
84 |
-
last_per_steps
|
85 |
)
|
86 |
|
87 |
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
88 |
-
trainer.train(
|
89 |
-
|
90 |
-
|
|
|
91 |
|
92 |
|
93 |
-
if __name__ ==
|
94 |
main()
|
|
|
1 |
+
from model import CFM, UNetT, DiT, Trainer
|
2 |
from model.utils import get_tokenizer
|
3 |
from model.dataset import load_dataset
|
4 |
|
|
|
9 |
n_mel_channels = 100
|
10 |
hop_length = 256
|
11 |
|
12 |
+
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
|
13 |
+
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
|
14 |
dataset_name = "Emilia_ZH_EN"
|
15 |
|
16 |
# -------------------------- Training Settings -------------------------- #
|
|
|
23 |
batch_size_type = "frame" # "frame" or "sample"
|
24 |
max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
|
25 |
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
|
26 |
+
max_grad_norm = 1.0
|
27 |
|
28 |
epochs = 11 # use linear decay, thus epochs control the slope
|
29 |
num_warmup_updates = 20000 # warmup steps
|
|
|
34 |
if exp_name == "F5TTS_Base":
|
35 |
wandb_resume_id = None
|
36 |
model_cls = DiT
|
37 |
+
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
38 |
elif exp_name == "E2TTS_Base":
|
39 |
wandb_resume_id = None
|
40 |
model_cls = UNetT
|
41 |
+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
42 |
|
43 |
|
44 |
# ----------------------------------------------------------------------- #
|
45 |
|
46 |
+
|
47 |
def main():
|
48 |
if tokenizer == "custom":
|
49 |
tokenizer_path = tokenizer_path
|
|
|
52 |
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
53 |
|
54 |
mel_spec_kwargs = dict(
|
55 |
+
target_sample_rate=target_sample_rate,
|
56 |
+
n_mel_channels=n_mel_channels,
|
57 |
+
hop_length=hop_length,
|
58 |
+
)
|
59 |
+
|
60 |
model = CFM(
|
61 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
62 |
+
mel_spec_kwargs=mel_spec_kwargs,
|
63 |
+
vocab_char_map=vocab_char_map,
|
|
|
|
|
|
|
|
|
64 |
)
|
65 |
|
66 |
trainer = Trainer(
|
67 |
model,
|
68 |
+
epochs,
|
69 |
learning_rate,
|
70 |
+
num_warmup_updates=num_warmup_updates,
|
71 |
+
save_per_updates=save_per_updates,
|
72 |
+
checkpoint_path=f"ckpts/{exp_name}",
|
73 |
+
batch_size=batch_size_per_gpu,
|
74 |
+
batch_size_type=batch_size_type,
|
75 |
+
max_samples=max_samples,
|
76 |
+
grad_accumulation_steps=grad_accumulation_steps,
|
77 |
+
max_grad_norm=max_grad_norm,
|
78 |
+
wandb_project="CFM-TTS",
|
79 |
+
wandb_run_name=exp_name,
|
80 |
+
wandb_resume_id=wandb_resume_id,
|
81 |
+
last_per_steps=last_per_steps,
|
82 |
)
|
83 |
|
84 |
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
85 |
+
trainer.train(
|
86 |
+
train_dataset,
|
87 |
+
resumable_with_seed=666, # seed for shuffling dataset
|
88 |
+
)
|
89 |
|
90 |
|
91 |
+
if __name__ == "__main__":
|
92 |
main()
|