attashe committed on
Commit
05ec035
·
1 Parent(s): 2db58c3

initial commit

Files changed (7)
  1. .gitignore +1 -0
  2. app.py +870 -0
  3. infer_utils.py +543 -0
  4. model.py +285 -0
  5. model_modules.py +658 -0
  6. model_utils.py +187 -0
  7. requirements.txt +26 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ .ipynb_checkpoints
app.py ADDED
@@ -0,0 +1,870 @@
1
+ # ruff: noqa: E402
2
+ # Above allows ruff to ignore E402: module level import not at top of file
3
+ import os
4
+ os.system('git clone https://github.com/NVIDIA/BigVGAN.git')
5
+
6
+
7
+ import re
8
+ import tempfile
9
+ from collections import OrderedDict
10
+ from importlib.resources import files
11
+
12
+ import click
13
+ import gradio as gr
14
+ import numpy as np
15
+ import soundfile as sf
16
+ import torchaudio
17
+ from cached_path import cached_path
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+
20
+ try:
21
+ import spaces
22
+
23
+ USING_SPACES = True
24
+ except ImportError:
25
+ USING_SPACES = False
26
+
27
+
28
+ def gpu_decorator(func):
29
+ if USING_SPACES:
30
+ return spaces.GPU(func)
31
+ else:
32
+ return func
33
+
34
+
35
+ from model import DiT, UNetT
36
+ from infer_utils import (
37
+ load_vocoder,
38
+ load_model,
39
+ preprocess_ref_audio_text,
40
+ infer_process,
41
+ remove_silence_for_generated_wav,
42
+ save_spectrogram,
43
+ )
44
+
45
+
46
+ DEFAULT_TTS_MODEL = "F5-TTS"
47
+ tts_model_choice = DEFAULT_TTS_MODEL
48
+
49
+
50
+ # load models
51
+
52
+ from huggingface_hub import hf_hub_download
53
+ import joblib
54
+
55
+ model_file = (  # keep the local checkpoint path returned by hf_hub_download (it is not a joblib pickle)
56
+ hf_hub_download(repo_id="attashe/F5-TTS-Ru-finetune", filename="model_last_bigvgan.safetensors")
57
+ )
58
+
59
+ vocab_file = (  # likewise, keep the downloaded vocab path
60
+ hf_hub_download(repo_id="attashe/F5-TTS-Ru-finetune", filename="vocab.txt")
61
+ )
62
+ print(f"Using model file: {model_file} and vocab: {vocab_file}")
63
+
64
+ vocoder = load_vocoder(vocoder_name="bigvgan")
65
+
66
+
67
+ def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
68
+ ckpt_path = model_file
69
+ vocab_path = vocab_file
70
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
71
+ return load_model(DiT, F5TTS_model_cfg, ckpt_path, vocab_file=vocab_path)
72
+
73
+
74
+ def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
75
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
76
+ return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
77
+
78
+
79
+ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
80
+ ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()
81
+ if ckpt_path.startswith("hf://"):
82
+ ckpt_path = str(cached_path(ckpt_path))
83
+ if vocab_path.startswith("hf://"):
84
+ vocab_path = str(cached_path(vocab_path))
85
+ if model_cfg is None:
86
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
87
+ return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
88
+
89
+
90
+ F5TTS_ema_model = load_f5tts()
91
+ E2TTS_ema_model = load_e2tts() if USING_SPACES else None
92
+ custom_ema_model, pre_custom_path = None, ""
93
+
94
+ chat_model_state = None
95
+ chat_tokenizer_state = None
96
+
97
+
98
+ @gpu_decorator
99
+ def generate_response(messages, model, tokenizer):
100
+ """Generate response using Qwen"""
101
+ text = tokenizer.apply_chat_template(
102
+ messages,
103
+ tokenize=False,
104
+ add_generation_prompt=True,
105
+ )
106
+
107
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
108
+ generated_ids = model.generate(
109
+ **model_inputs,
110
+ max_new_tokens=512,
111
+ temperature=0.7,
112
+ top_p=0.95,
113
+ )
114
+
115
+ generated_ids = [
116
+ output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
117
+ ]
118
+ return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
119
+
120
+
121
+ @gpu_decorator
122
+ def infer(
123
+ ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
124
+ ):
125
+ ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
126
+ gen_text = gen_text.lower().strip()
127
+ ref_text = ref_text.lower().strip()
128
+
129
+ if model == "F5-TTS":
130
+ ema_model = F5TTS_ema_model
131
+ elif model == "E2-TTS":
132
+ global E2TTS_ema_model
133
+ if E2TTS_ema_model is None:
134
+ show_info("Loading E2-TTS model...")
135
+ E2TTS_ema_model = load_e2tts()
136
+ ema_model = E2TTS_ema_model
137
+ elif isinstance(model, list) and model[0] == "Custom":
138
+ assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
139
+ global custom_ema_model, pre_custom_path
140
+ if pre_custom_path != model[1]:
141
+ show_info("Loading Custom TTS model...")
142
+ custom_ema_model = load_custom(model[1], vocab_path=model[2])
143
+ pre_custom_path = model[1]
144
+ ema_model = custom_ema_model
145
+
146
+ final_wave, final_sample_rate, combined_spectrogram = infer_process(
147
+ ref_audio,
148
+ ref_text,
149
+ gen_text,
150
+ ema_model,
151
+ vocoder,
152
+ cross_fade_duration=cross_fade_duration,
153
+ speed=speed,
154
+ show_info=show_info,
155
+ progress=gr.Progress(),
156
+ )
157
+
158
+ # Remove silence
159
+ if remove_silence:
160
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
161
+ sf.write(f.name, final_wave, final_sample_rate)
162
+ remove_silence_for_generated_wav(f.name)
163
+ final_wave, _ = torchaudio.load(f.name)
164
+ final_wave = final_wave.squeeze().cpu().numpy()
165
+
166
+ # Save the spectrogram
167
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
168
+ spectrogram_path = tmp_spectrogram.name
169
+ save_spectrogram(combined_spectrogram, spectrogram_path)
170
+
171
+ return (final_sample_rate, final_wave), spectrogram_path, ref_text
172
+
173
+
174
+ with gr.Blocks() as app_credits:
175
+ gr.Markdown("""
176
+ # Credits
177
+
178
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
179
+ * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
180
+ * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
181
+ """)
182
+ with gr.Blocks() as app_tts:
183
+ gr.Markdown("# Batched TTS")
184
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
185
+ gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
186
+ generate_btn = gr.Button("Synthesize", variant="primary")
187
+ with gr.Accordion("Advanced Settings", open=False):
188
+ ref_text_input = gr.Textbox(
189
+ label="Reference Text",
190
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
191
+ lines=2,
192
+ )
193
+ remove_silence = gr.Checkbox(
194
+ label="Remove Silences",
195
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
196
+ value=False,
197
+ )
198
+ speed_slider = gr.Slider(
199
+ label="Speed",
200
+ minimum=0.3,
201
+ maximum=2.0,
202
+ value=1.0,
203
+ step=0.1,
204
+ info="Adjust the speed of the audio.",
205
+ )
206
+ cross_fade_duration_slider = gr.Slider(
207
+ label="Cross-Fade Duration (s)",
208
+ minimum=0.0,
209
+ maximum=1.0,
210
+ value=0.15,
211
+ step=0.01,
212
+ info="Set the duration of the cross-fade between audio clips.",
213
+ )
214
+
215
+ audio_output = gr.Audio(label="Synthesized Audio")
216
+ spectrogram_output = gr.Image(label="Spectrogram")
217
+
218
+ @gpu_decorator
219
+ def basic_tts(
220
+ ref_audio_input,
221
+ ref_text_input,
222
+ gen_text_input,
223
+ remove_silence,
224
+ cross_fade_duration_slider,
225
+ speed_slider,
226
+ ):
227
+ audio_out, spectrogram_path, ref_text_out = infer(
228
+ ref_audio_input,
229
+ ref_text_input,
230
+ gen_text_input,
231
+ tts_model_choice,
232
+ remove_silence,
233
+ cross_fade_duration_slider,
234
+ speed_slider,
235
+ )
236
+ return audio_out, spectrogram_path, gr.update(value=ref_text_out)
237
+
238
+ generate_btn.click(
239
+ basic_tts,
240
+ inputs=[
241
+ ref_audio_input,
242
+ ref_text_input,
243
+ gen_text_input,
244
+ remove_silence,
245
+ cross_fade_duration_slider,
246
+ speed_slider,
247
+ ],
248
+ outputs=[audio_output, spectrogram_output, ref_text_input],
249
+ )
250
+
251
+
252
+ def parse_speechtypes_text(gen_text):
253
+ # Pattern to find {speechtype}
254
+ pattern = r"\{(.*?)\}"
255
+
256
+ # Split the text by the pattern
257
+ tokens = re.split(pattern, gen_text)
258
+
259
+ segments = []
260
+
261
+ current_style = "Regular"
262
+
263
+ for i in range(len(tokens)):
264
+ if i % 2 == 0:
265
+ # This is text
266
+ text = tokens[i].strip()
267
+ if text:
268
+ segments.append({"style": current_style, "text": text})
269
+ else:
270
+ # This is style
271
+ style = tokens[i].strip()
272
+ current_style = style
273
+
274
+ return segments
275
+
276
+
277
+ with gr.Blocks() as app_multistyle:
278
+ # New section for multistyle generation
279
+ gr.Markdown(
280
+ """
281
+ # Multiple Speech-Type Generation
282
+
283
+ This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
284
+ """
285
+ )
286
+
287
+ with gr.Row():
288
+ gr.Markdown(
289
+ """
290
+ **Example Input:**
291
+ {Regular} Hello, I'd like to order a sandwich please.
292
+ {Surprised} What do you mean you're out of bread?
293
+ {Sad} I really wanted a sandwich though...
294
+ {Angry} You know what, darn you and your little shop!
295
+ {Whisper} I'll just go back home and cry now.
296
+ {Shouting} Why me?!
297
+ """
298
+ )
299
+
300
+ gr.Markdown(
301
+ """
302
+ **Example Input 2:**
303
+ {Speaker1_Happy} Hello, I'd like to order a sandwich please.
304
+ {Speaker2_Regular} Sorry, we're out of bread.
305
+ {Speaker1_Sad} I really wanted a sandwich though...
306
+ {Speaker2_Whisper} I'll give you the last one I was hiding.
307
+ """
308
+ )
309
+
310
+ gr.Markdown(
311
+ "Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
312
+ )
313
+
314
+ # Regular speech type (mandatory)
315
+ with gr.Row():
316
+ with gr.Column():
317
+ regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
318
+ regular_insert = gr.Button("Insert Label", variant="secondary")
319
+ regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
320
+ regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
321
+
322
+ # Maximum number of speech type slots (including Regular)
323
+ max_speech_types = 100
324
+ speech_type_rows = [] # 99
325
+ speech_type_names = [regular_name] # 100
326
+ speech_type_audios = [regular_audio] # 100
327
+ speech_type_ref_texts = [regular_ref_text] # 100
328
+ speech_type_delete_btns = [] # 99
329
+ speech_type_insert_btns = [regular_insert] # 100
330
+
331
+ # Additional speech types (99 more)
332
+ for i in range(max_speech_types - 1):
333
+ with gr.Row(visible=False) as row:
334
+ with gr.Column():
335
+ name_input = gr.Textbox(label="Speech Type Name")
336
+ delete_btn = gr.Button("Delete Type", variant="secondary")
337
+ insert_btn = gr.Button("Insert Label", variant="secondary")
338
+ audio_input = gr.Audio(label="Reference Audio", type="filepath")
339
+ ref_text_input = gr.Textbox(label="Reference Text", lines=2)
340
+ speech_type_rows.append(row)
341
+ speech_type_names.append(name_input)
342
+ speech_type_audios.append(audio_input)
343
+ speech_type_ref_texts.append(ref_text_input)
344
+ speech_type_delete_btns.append(delete_btn)
345
+ speech_type_insert_btns.append(insert_btn)
346
+
347
+ # Button to add speech type
348
+ add_speech_type_btn = gr.Button("Add Speech Type")
349
+
350
+ # Keep track of current number of speech types
351
+ speech_type_count = gr.State(value=1)
352
+
353
+ # Function to add a speech type
354
+ def add_speech_type_fn(speech_type_count):
355
+ if speech_type_count < max_speech_types:
356
+ speech_type_count += 1
357
+ # Prepare updates for the rows
358
+ row_updates = []
359
+ for i in range(1, max_speech_types):
360
+ if i < speech_type_count:
361
+ row_updates.append(gr.update(visible=True))
362
+ else:
363
+ row_updates.append(gr.update())
364
+ else:
365
+ # Optionally, show a warning
366
+ row_updates = [gr.update() for _ in range(1, max_speech_types)]
367
+ return [speech_type_count] + row_updates
368
+
369
+ add_speech_type_btn.click(
370
+ add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows
371
+ )
372
+
373
+ # Function to delete a speech type
374
+ def make_delete_speech_type_fn(index):
375
+ def delete_speech_type_fn(speech_type_count):
376
+ # Prepare updates
377
+ row_updates = []
378
+
379
+ for i in range(1, max_speech_types):
380
+ if i == index:
381
+ row_updates.append(gr.update(visible=False))
382
+ else:
383
+ row_updates.append(gr.update())
384
+
385
+ speech_type_count = max(1, speech_type_count)
386
+
387
+ return [speech_type_count] + row_updates
388
+
389
+ return delete_speech_type_fn
390
+
391
+ # Update delete button clicks
392
+ for i, delete_btn in enumerate(speech_type_delete_btns):
393
+ delete_fn = make_delete_speech_type_fn(i)
394
+ delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows)
395
+
396
+ # Text input for the prompt
397
+ gen_text_input_multistyle = gr.Textbox(
398
+ label="Text to Generate",
399
+ lines=10,
400
+ placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
401
+ )
402
+
403
+ def make_insert_speech_type_fn(index):
404
+ def insert_speech_type_fn(current_text, speech_type_name):
405
+ current_text = current_text or ""
406
+ speech_type_name = speech_type_name or "None"
407
+ updated_text = current_text + f"{{{speech_type_name}}} "
408
+ return gr.update(value=updated_text)
409
+
410
+ return insert_speech_type_fn
411
+
412
+ for i, insert_btn in enumerate(speech_type_insert_btns):
413
+ insert_fn = make_insert_speech_type_fn(i)
414
+ insert_btn.click(
415
+ insert_fn,
416
+ inputs=[gen_text_input_multistyle, speech_type_names[i]],
417
+ outputs=gen_text_input_multistyle,
418
+ )
419
+
420
+ with gr.Accordion("Advanced Settings", open=False):
421
+ remove_silence_multistyle = gr.Checkbox(
422
+ label="Remove Silences",
423
+ value=True,
424
+ )
425
+
426
+ # Generate button
427
+ generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
428
+
429
+ # Output audio
430
+ audio_output_multistyle = gr.Audio(label="Synthesized Audio")
431
+
432
+ @gpu_decorator
433
+ def generate_multistyle_speech(
434
+ gen_text,
435
+ *args,
436
+ ):
437
+ speech_type_names_list = args[:max_speech_types]
438
+ speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
439
+ speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
440
+ remove_silence = args[3 * max_speech_types]
441
+ # Collect the speech types and their audios into a dict
442
+ speech_types = OrderedDict()
443
+
444
+ ref_text_idx = 0
445
+ for name_input, audio_input, ref_text_input in zip(
446
+ speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
447
+ ):
448
+ if name_input and audio_input:
449
+ speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
450
+ else:
451
+ speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
452
+ ref_text_idx += 1
453
+
454
+ # Parse the gen_text into segments
455
+ segments = parse_speechtypes_text(gen_text)
456
+
457
+ # For each segment, generate speech
458
+ generated_audio_segments = []
459
+ current_style = "Regular"
460
+
461
+ for segment in segments:
462
+ style = segment["style"]
463
+ text = segment["text"]
464
+
465
+ if style in speech_types:
466
+ current_style = style
467
+ else:
468
+ # If style not available, default to Regular
469
+ current_style = "Regular"
470
+
471
+ ref_audio = speech_types[current_style]["audio"]
472
+ ref_text = speech_types[current_style].get("ref_text", "")
473
+
474
+ # Generate speech for this segment
475
+ audio_out, _, ref_text_out = infer(
476
+ ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
477
+ ) # show_info=print keeps the page from scrolling to the top while generating
478
+ sr, audio_data = audio_out
479
+
480
+ generated_audio_segments.append(audio_data)
481
+ speech_types[current_style]["ref_text"] = ref_text_out
482
+
483
+ # Concatenate all audio segments
484
+ if generated_audio_segments:
485
+ final_audio_data = np.concatenate(generated_audio_segments)
486
+ return [(sr, final_audio_data)] + [
487
+ gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
488
+ ]
489
+ else:
490
+ gr.Warning("No audio generated.")
491
+ return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types]
492
+
493
+ generate_multistyle_btn.click(
494
+ generate_multistyle_speech,
495
+ inputs=[
496
+ gen_text_input_multistyle,
497
+ ]
498
+ + speech_type_names
499
+ + speech_type_audios
500
+ + speech_type_ref_texts
501
+ + [
502
+ remove_silence_multistyle,
503
+ ],
504
+ outputs=[audio_output_multistyle] + speech_type_ref_texts,
505
+ )
506
+
507
+ # Validation function to disable Generate button if speech types are missing
508
+ def validate_speech_types(gen_text, regular_name, *args):
509
+ speech_type_names_list = args[:max_speech_types]
510
+
511
+ # Collect the speech types names
512
+ speech_types_available = set()
513
+ if regular_name:
514
+ speech_types_available.add(regular_name)
515
+ for name_input in speech_type_names_list:
516
+ if name_input:
517
+ speech_types_available.add(name_input)
518
+
519
+ # Parse the gen_text to get the speech types used
520
+ segments = parse_speechtypes_text(gen_text)
521
+ speech_types_in_text = set(segment["style"] for segment in segments)
522
+
523
+ # Check if all speech types in text are available
524
+ missing_speech_types = speech_types_in_text - speech_types_available
525
+
526
+ if missing_speech_types:
527
+ # Disable the generate button
528
+ return gr.update(interactive=False)
529
+ else:
530
+ # Enable the generate button
531
+ return gr.update(interactive=True)
532
+
533
+ gen_text_input_multistyle.change(
534
+ validate_speech_types,
535
+ inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
536
+ outputs=generate_multistyle_btn,
537
+ )
538
+
539
+
540
+ with gr.Blocks() as app_chat:
541
+ gr.Markdown(
542
+ """
543
+ # Voice Chat
544
+ Have a conversation with an AI using your reference voice!
545
+ 1. Upload a reference audio clip and optionally its transcript.
546
+ 2. Load the chat model.
547
+ 3. Record your message through your microphone.
548
+ 4. The AI will respond using the reference voice.
549
+ """
550
+ )
551
+
552
+ if not USING_SPACES:
553
+ load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
554
+
555
+ chat_interface_container = gr.Column(visible=False)
556
+
557
+ @gpu_decorator
558
+ def load_chat_model():
559
+ global chat_model_state, chat_tokenizer_state
560
+ if chat_model_state is None:
561
+ show_info = gr.Info
562
+ show_info("Loading chat model...")
563
+ model_name = "Qwen/Qwen2.5-3B-Instruct"
564
+ chat_model_state = AutoModelForCausalLM.from_pretrained(
565
+ model_name, torch_dtype="auto", device_map="auto"
566
+ )
567
+ chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
568
+ show_info("Chat model loaded.")
569
+
570
+ return gr.update(visible=False), gr.update(visible=True)
571
+
572
+ load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
573
+
574
+ else:
575
+ chat_interface_container = gr.Column()
576
+
577
+ if chat_model_state is None:
578
+ model_name = "Qwen/Qwen2.5-3B-Instruct"
579
+ chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
580
+ chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
581
+
582
+ with chat_interface_container:
583
+ with gr.Row():
584
+ with gr.Column():
585
+ ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
586
+ with gr.Column():
587
+ with gr.Accordion("Advanced Settings", open=False):
588
+ remove_silence_chat = gr.Checkbox(
589
+ label="Remove Silences",
590
+ value=True,
591
+ )
592
+ ref_text_chat = gr.Textbox(
593
+ label="Reference Text",
594
+ info="Optional: Leave blank to auto-transcribe",
595
+ lines=2,
596
+ )
597
+ system_prompt_chat = gr.Textbox(
598
+ label="System Prompt",
599
+ value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
600
+ lines=2,
601
+ )
602
+
603
+ chatbot_interface = gr.Chatbot(label="Conversation")
604
+
605
+ with gr.Row():
606
+ with gr.Column():
607
+ audio_input_chat = gr.Microphone(
608
+ label="Speak your message",
609
+ type="filepath",
610
+ )
611
+ audio_output_chat = gr.Audio(autoplay=True)
612
+ with gr.Column():
613
+ text_input_chat = gr.Textbox(
614
+ label="Type your message",
615
+ lines=1,
616
+ )
617
+ send_btn_chat = gr.Button("Send Message")
618
+ clear_btn_chat = gr.Button("Clear Conversation")
619
+
620
+ conversation_state = gr.State(
621
+ value=[
622
+ {
623
+ "role": "system",
624
+ "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
625
+ }
626
+ ]
627
+ )
628
+
629
+ # Modify process_audio_input to use model and tokenizer from state
630
+ @gpu_decorator
631
+ def process_audio_input(audio_path, text, history, conv_state):
632
+ """Handle audio or text input from user"""
633
+
634
+ if not audio_path and not text.strip():
635
+ return history, conv_state, ""
636
+
637
+ if audio_path:
638
+ text = preprocess_ref_audio_text(audio_path, text)[1]
639
+
640
+ if not text.strip():
641
+ return history, conv_state, ""
642
+
643
+ conv_state.append({"role": "user", "content": text})
644
+ history.append((text, None))
645
+
646
+ response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
647
+
648
+ conv_state.append({"role": "assistant", "content": response})
649
+ history[-1] = (text, response)
650
+
651
+ return history, conv_state, ""
652
+
653
+ @gpu_decorator
654
+ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
655
+ """Generate TTS audio for AI response"""
656
+ if not history or not ref_audio:
657
+ return None
658
+
659
+ last_user_message, last_ai_response = history[-1]
660
+ if not last_ai_response:
661
+ return None
662
+
663
+ audio_result, _, ref_text_out = infer(
664
+ ref_audio,
665
+ ref_text,
666
+ last_ai_response,
667
+ tts_model_choice,
668
+ remove_silence,
669
+ cross_fade_duration=0.15,
670
+ speed=1.0,
671
+ show_info=print, # keeps the page from scrolling to the top while generating
672
+ )
673
+ return audio_result, gr.update(value=ref_text_out)
674
+
675
+ def clear_conversation():
676
+ """Reset the conversation"""
677
+ return [], [
678
+ {
679
+ "role": "system",
680
+ "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
681
+ }
682
+ ]
683
+
684
+ def update_system_prompt(new_prompt):
685
+ """Update the system prompt and reset the conversation"""
686
+ new_conv_state = [{"role": "system", "content": new_prompt}]
687
+ return [], new_conv_state
688
+
689
+ # Handle audio input
690
+ audio_input_chat.stop_recording(
691
+ process_audio_input,
692
+ inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
693
+ outputs=[chatbot_interface, conversation_state],
694
+ ).then(
695
+ generate_audio_response,
696
+ inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
697
+ outputs=[audio_output_chat, ref_text_chat],
698
+ ).then(
699
+ lambda: None,
700
+ None,
701
+ audio_input_chat,
702
+ )
703
+
704
+ # Handle text input
705
+ text_input_chat.submit(
706
+ process_audio_input,
707
+ inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
708
+ outputs=[chatbot_interface, conversation_state],
709
+ ).then(
710
+ generate_audio_response,
711
+ inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
712
+ outputs=[audio_output_chat, ref_text_chat],
713
+ ).then(
714
+ lambda: None,
715
+ None,
716
+ text_input_chat,
717
+ )
718
+
719
+ # Handle send button
720
+ send_btn_chat.click(
721
+ process_audio_input,
722
+ inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
723
+ outputs=[chatbot_interface, conversation_state],
724
+ ).then(
725
+ generate_audio_response,
726
+ inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
727
+ outputs=[audio_output_chat, ref_text_chat],
728
+ ).then(
729
+ lambda: None,
730
+ None,
731
+ text_input_chat,
732
+ )
733
+
734
+ # Handle clear button
735
+ clear_btn_chat.click(
736
+ clear_conversation,
737
+ outputs=[chatbot_interface, conversation_state],
738
+ )
739
+
740
+ # Handle system prompt change and reset conversation
741
+ system_prompt_chat.change(
742
+ update_system_prompt,
743
+ inputs=system_prompt_chat,
744
+ outputs=[chatbot_interface, conversation_state],
745
+ )
746
+
747
+
748
+ with gr.Blocks() as app:
749
+ gr.Markdown(
750
+ """
751
+ # E2/F5 TTS
752
+
753
+ This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
754
+
755
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
756
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
757
+
758
+ The checkpoints currently support English and Chinese.
759
+
760
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 s with the ✂ control in the bottom right corner (otherwise the auto-trimmed result may be suboptimal).
761
+
762
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
763
+ """
764
+ )
765
+
766
+ last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")
767
+
768
+ def load_last_used_custom():
769
+ try:
770
+ with open(last_used_custom, "r") as f:
771
+ return f.read().split(",")
772
+ except FileNotFoundError:
773
+ last_used_custom.parent.mkdir(parents=True, exist_ok=True)
774
+ return [
775
+ "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
776
+ "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
777
+ ]
778
+
779
+ def switch_tts_model(new_choice):
780
+ global tts_model_choice
781
+ if new_choice == "Custom": # override in case webpage is refreshed
782
+ custom_ckpt_path, custom_vocab_path = load_last_used_custom()
783
+ tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
784
+ return gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path)
785
+ else:
786
+ tts_model_choice = new_choice
787
+ return gr.update(visible=False), gr.update(visible=False)
788
+
789
+ def set_custom_model(custom_ckpt_path, custom_vocab_path):
790
+ global tts_model_choice
791
+ tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
792
+ with open(last_used_custom, "w") as f:
793
+ f.write(f"{custom_ckpt_path},{custom_vocab_path}")
794
+
795
+ with gr.Row():
796
+ if not USING_SPACES:
797
+ choose_tts_model = gr.Radio(
798
+ choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
799
+ )
800
+ else:
801
+ choose_tts_model = gr.Radio(
802
+ choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
803
+ )
804
+ custom_ckpt_path = gr.Dropdown(
805
+ choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
806
+ value=load_last_used_custom()[0],
807
+ allow_custom_value=True,
808
+ label="MODEL CKPT: local_path | hf://user_id/repo_id/model_ckpt",
809
+ visible=False,
810
+ )
811
+ custom_vocab_path = gr.Dropdown(
812
+ choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
813
+ value=load_last_used_custom()[1],
814
+ allow_custom_value=True,
815
+ label="VOCAB FILE: local_path | hf://user_id/repo_id/vocab_file",
816
+ visible=False,
817
+ )
818
+
819
+ choose_tts_model.change(
820
+ switch_tts_model,
821
+ inputs=[choose_tts_model],
822
+ outputs=[custom_ckpt_path, custom_vocab_path],
823
+ show_progress="hidden",
824
+ )
825
+ custom_ckpt_path.change(
826
+ set_custom_model,
827
+ inputs=[custom_ckpt_path, custom_vocab_path],
828
+ show_progress="hidden",
829
+ )
830
+ custom_vocab_path.change(
831
+ set_custom_model,
832
+ inputs=[custom_ckpt_path, custom_vocab_path],
833
+ show_progress="hidden",
834
+ )
835
+
836
+ gr.TabbedInterface(
837
+ [app_tts, app_multistyle, app_chat, app_credits],
838
+ ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
839
+ )
840
+
841
+
842
+ @click.command()
843
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
844
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
845
+ @click.option(
846
+ "--share",
847
+ "-s",
848
+ default=False,
849
+ is_flag=True,
850
+ help="Share the app via Gradio share link",
851
+ )
852
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
853
+ @click.option(
854
+ "--root_path",
855
+ "-r",
856
+ default=None,
857
+ type=str,
858
+ help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
859
+ )
860
+ def main(port, host, share, api, root_path):
861
+ global app
862
+ print("Starting app...")
863
+ app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api, root_path=root_path)
864
+
865
+
866
+ if __name__ == "__main__":
867
+ if not USING_SPACES:
868
+ main()
869
+ else:
870
+ app.queue().launch()
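
For quick local testing, the helpers defined in app.py can also be called outside the Gradio UI. The snippet below is a minimal, untested sketch and is not part of the commit: the reference clip `ref.wav`, the output name `generated.wav`, and the Russian sample sentence are placeholders, and importing `app` runs the module-level setup above (BigVGAN clone, checkpoint download, model loading).

```python
# Hypothetical usage sketch; file paths and the sample sentence are placeholders.
import soundfile as sf

from app import DEFAULT_TTS_MODEL, infer

# Synthesize speech in the voice of ref.wav. An empty reference text lets
# preprocess_ref_audio_text() transcribe the clip with Whisper first.
(sample_rate, wave), spectrogram_path, ref_text = infer(
    "ref.wav",                # reference audio (placeholder path)
    "",                       # reference text; auto-transcribed when empty
    "Привет! Это проверка синтеза речи.",  # text to generate (placeholder)
    DEFAULT_TTS_MODEL,        # "F5-TTS"
    remove_silence=False,
    cross_fade_duration=0.15,
    speed=1.0,
    show_info=print,          # plain printing instead of gr.Info pop-ups
)

sf.write("generated.wav", wave, sample_rate)
print("Saved generated.wav; spectrogram written to", spectrogram_path)
```
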
infer_utils.py ADDED
@@ -0,0 +1,543 @@
1
+ # A unified script for inference process
2
+ # Make adjustments inside functions, and consider both the gradio and cli scripts if the function output format needs to change
3
+ import os
4
+ import sys
5
+
6
+ os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
7
+ sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
8
+
9
+ import hashlib
10
+ import re
11
+ import tempfile
12
+ from importlib.resources import files
13
+
14
+ import matplotlib
15
+
16
+ matplotlib.use("Agg")
17
+
18
+ import matplotlib.pylab as plt
19
+ import numpy as np
20
+ import torch
21
+ import torchaudio
22
+ import tqdm
23
+ from huggingface_hub import snapshot_download, hf_hub_download
24
+ from pydub import AudioSegment, silence
25
+ from transformers import pipeline
26
+ from vocos import Vocos
27
+
28
+ from model import CFM
29
+ from model_utils import (
30
+ get_tokenizer,
31
+ convert_char_to_pinyin,
32
+ )
33
+
34
+ _ref_audio_cache = {}
35
+
36
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
37
+
38
+ # -----------------------------------------
39
+
40
+ target_sample_rate = 24000
41
+ n_mel_channels = 100
42
+ hop_length = 256
43
+ win_length = 1024
44
+ n_fft = 1024
45
+ mel_spec_type = "bigvgan"
46
+ target_rms = 0.1
47
+ cross_fade_duration = 0.15
48
+ ode_method = "euler"
49
+ nfe_step = 32 # 16, 32
50
+ cfg_strength = 2.0
51
+ sway_sampling_coef = -1.0
52
+ speed = 1.0
53
+ fix_duration = None
54
+
55
+ # -----------------------------------------
56
+
57
+
58
+ # chunk text into smaller pieces
59
+
60
+
61
+ def chunk_text(text, max_chars=135):
62
+ """
63
+ Splits the input text into chunks, each with a maximum number of characters.
64
+
65
+ Args:
66
+ text (str): The text to be split.
67
+ max_chars (int): The maximum number of characters per chunk.
68
+
69
+ Returns:
70
+ List[str]: A list of text chunks.
71
+ """
72
+ chunks = []
73
+ current_chunk = ""
74
+ # Split the text into sentences based on punctuation followed by whitespace
75
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[；：，。！？])", text)
76
+
77
+ for sentence in sentences:
78
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
79
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
80
+ else:
81
+ if current_chunk:
82
+ chunks.append(current_chunk.strip())
83
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
84
+
85
+ if current_chunk:
86
+ chunks.append(current_chunk.strip())
87
+
88
+ return chunks
89
+
90
+
91
+ # load vocoder
92
+ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
93
+ if vocoder_name == "vocos":
94
+ # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
95
+ if is_local:
96
+ print(f"Load vocos from local path {local_path}")
97
+ config_path = f"{local_path}/config.yaml"
98
+ model_path = f"{local_path}/pytorch_model.bin"
99
+ else:
100
+ print("Download Vocos from huggingface charactr/vocos-mel-24khz")
101
+ repo_id = "charactr/vocos-mel-24khz"
102
+ config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
103
+ model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
104
+ vocoder = Vocos.from_hparams(config_path)
105
+ state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
106
+ from vocos.feature_extractors import EncodecFeatures
107
+
108
+ if isinstance(vocoder.feature_extractor, EncodecFeatures):
109
+ encodec_parameters = {
110
+ "feature_extractor.encodec." + key: value
111
+ for key, value in vocoder.feature_extractor.encodec.state_dict().items()
112
+ }
113
+ state_dict.update(encodec_parameters)
114
+ vocoder.load_state_dict(state_dict)
115
+ vocoder = vocoder.eval().to(device)
116
+ elif vocoder_name == "bigvgan":
117
+ try:
118
+ import sys
119
+ sys.path.append('BigVGAN')
120
+ import bigvgan
121
+ except ImportError:
122
+ print("You need to follow the README to init submodule and change the BigVGAN source code.")
123
+ if is_local:
124
+ """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
125
+ vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
126
+ else:
127
+ local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
128
+ vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
129
+
130
+ vocoder.remove_weight_norm()
131
+ vocoder = vocoder.eval().to(device)
132
+ return vocoder
133
+
134
+
135
+ # load asr pipeline
136
+
137
+ asr_pipe = None
138
+
139
+
140
+ def initialize_asr_pipeline(device: str = device, dtype=None):
141
+ if dtype is None:
142
+ dtype = (
143
+ torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
144
+ )
145
+ global asr_pipe
146
+ asr_pipe = pipeline(
147
+ "automatic-speech-recognition",
148
+ model="openai/whisper-large-v3-turbo",
149
+ torch_dtype=dtype,
150
+ device=device,
151
+ )
152
+
153
+
154
+ # transcribe
155
+
156
+
157
+ def transcribe(ref_audio, language=None):
158
+ global asr_pipe
159
+ if asr_pipe is None:
160
+ initialize_asr_pipeline(device=device)
161
+ return asr_pipe(
162
+ ref_audio,
163
+ chunk_length_s=30,
164
+ batch_size=128,
165
+ generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
166
+ return_timestamps=False,
167
+ )["text"].strip()
168
+
169
+
170
+ # load model checkpoint for inference
171
+
172
+
173
+ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
174
+ if dtype is None:
175
+ dtype = (
176
+ torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
177
+ )
178
+ model = model.to(dtype)
179
+
180
+ ckpt_type = ckpt_path.split(".")[-1]
181
+ if ckpt_type == "safetensors":
182
+ from safetensors.torch import load_file
183
+
184
+ checkpoint = load_file(ckpt_path, device=device)
185
+ else:
186
+ checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
187
+
188
+ if use_ema:
189
+ if ckpt_type == "safetensors":
190
+ checkpoint = {"ema_model_state_dict": checkpoint}
191
+ checkpoint["model_state_dict"] = {
192
+ k.replace("ema_model.", ""): v
193
+ for k, v in checkpoint["ema_model_state_dict"].items()
194
+ if k not in ["initted", "step"]
195
+ }
196
+
197
+ # patch for backward compatibility, 305e3ea
198
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
199
+ if key in checkpoint["model_state_dict"]:
200
+ del checkpoint["model_state_dict"][key]
201
+
202
+ model.load_state_dict(checkpoint["model_state_dict"])
203
+ else:
204
+ if ckpt_type == "safetensors":
205
+ checkpoint = {"model_state_dict": checkpoint}
206
+ model.load_state_dict(checkpoint["model_state_dict"])
207
+
208
+ del checkpoint
209
+ torch.cuda.empty_cache()
210
+
211
+ return model.to(device)
212
+
213
+
214
+ # load model for inference
215
+
216
+
217
+ def load_model(
218
+ model_cls,
219
+ model_cfg,
220
+ ckpt_path,
221
+ mel_spec_type=mel_spec_type,
222
+ vocab_file="",
223
+ ode_method=ode_method,
224
+ use_ema=True,
225
+ device=device,
226
+ ):
227
+ if vocab_file == "":
228
+ vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
229
+ tokenizer = "custom"
230
+
231
+ print("\nvocab : ", vocab_file)
232
+ print("token : ", tokenizer)
233
+ print("model : ", ckpt_path, "\n")
234
+
235
+ vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
236
+ model = CFM(
237
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
238
+ mel_spec_kwargs=dict(
239
+ n_fft=n_fft,
240
+ hop_length=hop_length,
241
+ win_length=win_length,
242
+ n_mel_channels=n_mel_channels,
243
+ target_sample_rate=target_sample_rate,
244
+ mel_spec_type=mel_spec_type,
245
+ ),
246
+ odeint_kwargs=dict(
247
+ method=ode_method,
248
+ ),
249
+ vocab_char_map=vocab_char_map,
250
+ ).to(device)
251
+
252
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
253
+ model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
254
+
255
+ return model
256
+
257
+
258
+ def remove_silence_edges(audio, silence_threshold=-42):
259
+ # Remove silence from the start
260
+ non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
261
+ audio = audio[non_silent_start_idx:]
262
+
263
+ # Remove silence from the end
264
+ non_silent_end_duration = audio.duration_seconds
265
+ for ms in reversed(audio):
266
+ if ms.dBFS > silence_threshold:
267
+ break
268
+ non_silent_end_duration -= 0.001
269
+ trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
270
+
271
+ return trimmed_audio
272
+
273
+
274
+ # preprocess reference audio and text
275
+
276
+
277
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
278
+ show_info("Converting audio...")
279
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
280
+ aseg = AudioSegment.from_file(ref_audio_orig)
281
+
282
+ if clip_short:
283
+ # 1. try to find long silence for clipping
284
+ non_silent_segs = silence.split_on_silence(
285
+ aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
286
+ )
287
+ non_silent_wave = AudioSegment.silent(duration=0)
288
+ for non_silent_seg in non_silent_segs:
289
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
290
+ show_info("Audio is over 15s, clipping short. (1)")
291
+ break
292
+ non_silent_wave += non_silent_seg
293
+
294
+ # 2. try to find short silence for clipping if 1. failed
295
+ if len(non_silent_wave) > 15000:
296
+ non_silent_segs = silence.split_on_silence(
297
+ aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
298
+ )
299
+ non_silent_wave = AudioSegment.silent(duration=0)
300
+ for non_silent_seg in non_silent_segs:
301
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
302
+ show_info("Audio is over 15s, clipping short. (2)")
303
+ break
304
+ non_silent_wave += non_silent_seg
305
+
306
+ aseg = non_silent_wave
307
+
308
+ # 3. if no proper silence found for clipping
309
+ if len(aseg) > 15000:
310
+ aseg = aseg[:15000]
311
+ show_info("Audio is over 15s, clipping short. (3)")
312
+
313
+ aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
314
+ aseg.export(f.name, format="wav")
315
+ ref_audio = f.name
316
+
317
+ # Compute a hash of the reference audio file
318
+ with open(ref_audio, "rb") as audio_file:
319
+ audio_data = audio_file.read()
320
+ audio_hash = hashlib.md5(audio_data).hexdigest()
321
+
322
+ if not ref_text.strip():
323
+ global _ref_audio_cache
324
+ if audio_hash in _ref_audio_cache:
325
+ # Use cached asr transcription
326
+ show_info("Using cached reference text...")
327
+ ref_text = _ref_audio_cache[audio_hash]
328
+ else:
329
+ show_info("No reference text provided, transcribing reference audio...")
330
+ ref_text = transcribe(ref_audio)
331
+ # Cache the transcribed text (custom ref_text is not cached, so users can still tweak it manually)
332
+ _ref_audio_cache[audio_hash] = ref_text
333
+ else:
334
+ show_info("Using custom reference text...")
335
+
336
+ # Ensure ref_text ends with a proper sentence-ending punctuation
337
+ if not ref_text.endswith(". ") and not ref_text.endswith("。"):
338
+ if ref_text.endswith("."):
339
+ ref_text += " "
340
+ else:
341
+ ref_text += ". "
342
+
343
+ print("ref_text ", ref_text)
344
+
345
+ return ref_audio, ref_text
346
+
347
+
348
+ # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
349
+
350
+
351
+ def infer_process(
352
+ ref_audio,
353
+ ref_text,
354
+ gen_text,
355
+ model_obj,
356
+ vocoder,
357
+ mel_spec_type=mel_spec_type,
358
+ show_info=print,
359
+ progress=tqdm,
360
+ target_rms=target_rms,
361
+ cross_fade_duration=cross_fade_duration,
362
+ nfe_step=nfe_step,
363
+ cfg_strength=cfg_strength,
364
+ sway_sampling_coef=sway_sampling_coef,
365
+ speed=speed,
366
+ fix_duration=fix_duration,
367
+ device=device,
368
+ ):
369
+ # Split the input text into batches
370
+ audio, sr = torchaudio.load(ref_audio)
371
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
372
+ print(f'{max_chars=}')
373
+ max_chars = 300 # hard override of the heuristic above (upstream default is 135)
374
+
375
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
376
+ for i, gen_text in enumerate(gen_text_batches):
377
+ print(f"gen_text {i}", gen_text)
378
+
379
+ show_info(f"Generating audio in {len(gen_text_batches)} batches...")
380
+ return infer_batch_process(
381
+ (audio, sr),
382
+ ref_text,
383
+ gen_text_batches,
384
+ model_obj,
385
+ vocoder,
386
+ mel_spec_type=mel_spec_type,
387
+ progress=progress,
388
+ target_rms=target_rms,
389
+ cross_fade_duration=cross_fade_duration,
390
+ nfe_step=nfe_step,
391
+ cfg_strength=cfg_strength,
392
+ sway_sampling_coef=sway_sampling_coef,
393
+ speed=speed,
394
+ fix_duration=fix_duration,
395
+ device=device,
396
+ )
397
+
398
+
399
+ # infer batches
400
+
401
+
402
+ def infer_batch_process(
403
+ ref_audio,
404
+ ref_text,
405
+ gen_text_batches,
406
+ model_obj,
407
+ vocoder,
408
+ mel_spec_type="vocos",
409
+ progress=tqdm,
410
+ target_rms=0.1,
411
+ cross_fade_duration=0.15,
412
+ nfe_step=32,
413
+ cfg_strength=2.0,
414
+ sway_sampling_coef=-1,
415
+ speed=1,
416
+ fix_duration=None,
417
+ device=None,
418
+ ):
419
+ audio, sr = ref_audio
420
+ if audio.shape[0] > 1:
421
+ audio = torch.mean(audio, dim=0, keepdim=True)
422
+
423
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
424
+ if rms < target_rms:
425
+ audio = audio * target_rms / rms
426
+ if sr != target_sample_rate:
427
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
428
+ audio = resampler(audio)
429
+ audio = audio.to(device)
430
+
431
+ generated_waves = []
432
+ spectrograms = []
433
+
434
+ if len(ref_text[-1].encode("utf-8")) == 1:
435
+ ref_text = ref_text + " "
436
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
437
+ # Prepare the text
438
+ text_list = [ref_text + gen_text]
439
+ final_text_list = convert_char_to_pinyin(text_list)
440
+
441
+ ref_audio_len = audio.shape[-1] // hop_length
442
+ if fix_duration is not None:
443
+ duration = int(fix_duration * target_sample_rate / hop_length)
444
+ else:
445
+ # Calculate duration
446
+ ref_text_len = len(ref_text.encode("utf-8"))
447
+ gen_text_len = len(gen_text.encode("utf-8"))
448
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
449
+
450
+ # inference
451
+ with torch.inference_mode():
452
+ generated, _ = model_obj.sample(
453
+ cond=audio,
454
+ text=final_text_list,
455
+ duration=duration,
456
+ steps=nfe_step,
457
+ cfg_strength=cfg_strength,
458
+ sway_sampling_coef=sway_sampling_coef,
459
+ )
460
+
461
+ generated = generated.to(torch.float32)
462
+ generated = generated[:, ref_audio_len:, :]
463
+ generated_mel_spec = generated.permute(0, 2, 1)
464
+ if mel_spec_type == "vocos":
465
+ generated_wave = vocoder.decode(generated_mel_spec)
466
+ elif mel_spec_type == "bigvgan":
467
+ generated_wave = vocoder(generated_mel_spec)
468
+ if rms < target_rms:
469
+ generated_wave = generated_wave * rms / target_rms
470
+
471
+ # wav -> numpy
472
+ generated_wave = generated_wave.squeeze().cpu().numpy()
473
+
474
+ generated_waves.append(generated_wave)
475
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
476
+
477
+ # Combine all generated waves with cross-fading
478
+ if cross_fade_duration <= 0:
479
+ # Simply concatenate
480
+ final_wave = np.concatenate(generated_waves)
481
+ else:
482
+ final_wave = generated_waves[0]
483
+ for i in range(1, len(generated_waves)):
484
+ prev_wave = final_wave
485
+ next_wave = generated_waves[i]
486
+
487
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
488
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
489
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
490
+
491
+ if cross_fade_samples <= 0:
492
+ # No overlap possible, concatenate
493
+ final_wave = np.concatenate([prev_wave, next_wave])
494
+ continue
495
+
496
+ # Overlapping parts
497
+ prev_overlap = prev_wave[-cross_fade_samples:]
498
+ next_overlap = next_wave[:cross_fade_samples]
499
+
500
+ # Fade out and fade in
501
+ fade_out = np.linspace(1, 0, cross_fade_samples)
502
+ fade_in = np.linspace(0, 1, cross_fade_samples)
503
+
504
+ # Cross-faded overlap
505
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
506
+
507
+ # Combine
508
+ new_wave = np.concatenate(
509
+ [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
510
+ )
511
+
512
+ final_wave = new_wave
513
+
514
+ # Create a combined spectrogram
515
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
516
+
517
+ return final_wave, target_sample_rate, combined_spectrogram
518
+
519
+
520
+ # remove silence from generated wav
521
+
522
+
523
+ def remove_silence_for_generated_wav(filename):
524
+ aseg = AudioSegment.from_file(filename)
525
+ non_silent_segs = silence.split_on_silence(
526
+ aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500, seek_step=10
527
+ )
528
+ non_silent_wave = AudioSegment.silent(duration=0)
529
+ for non_silent_seg in non_silent_segs:
530
+ non_silent_wave += non_silent_seg
531
+ aseg = non_silent_wave
532
+ aseg.export(filename, format="wav")
533
+
534
+
535
+ # save spectrogram
536
+
537
+
538
+ def save_spectrogram(spectrogram, path):
539
+ plt.figure(figsize=(12, 4))
540
+ plt.imshow(spectrogram, origin="lower", aspect="auto")
541
+ plt.colorbar()
542
+ plt.savefig(path)
543
+ plt.close()
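
Taken together, these helpers form a self-contained inference pipeline: load a vocoder, load a CFM/DiT checkpoint, normalize the reference audio and text, then run chunked generation with cross-fading. Below is an illustrative, untested sketch of that flow; `ckpt.safetensors`, `vocab.txt`, `ref.wav`, and `out.wav` are placeholder paths, not files shipped in this commit.

```python
# Hypothetical end-to-end sketch built only from the functions above; paths are placeholders.
import soundfile as sf

from model import DiT
from infer_utils import load_vocoder, load_model, preprocess_ref_audio_text, infer_process

vocoder = load_vocoder(vocoder_name="bigvgan")  # "vocos" is the other supported option

model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, model_cfg, "ckpt.safetensors", vocab_file="vocab.txt")

# Clip/transcribe the reference; Whisper is used when the reference text is empty.
ref_audio, ref_text = preprocess_ref_audio_text("ref.wav", "")

wave, sample_rate, _spectrogram = infer_process(
    ref_audio,
    ref_text,
    "Text to synthesize in the reference speaker's voice.",
    ema_model,
    vocoder,
)
sf.write("out.wav", wave, sample_rate)
```
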
model.py ADDED
@@ -0,0 +1,285 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from random import random
13
+ from typing import Callable
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ from torch.nn.utils.rnn import pad_sequence
19
+ from torchdiffeq import odeint
20
+
21
+ from model_modules import MelSpec
22
+ from model_utils import (
23
+ default,
24
+ exists,
25
+ lens_to_mask,
26
+ list_str_to_idx,
27
+ list_str_to_tensor,
28
+ mask_from_frac_lengths,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma=0.0,
37
+ odeint_kwargs: dict = dict(
38
+ # atol = 1e-5,
39
+ # rtol = 1e-5,
40
+ method="euler" # 'midpoint'
41
+ ),
42
+ audio_drop_prob=0.3,
43
+ cond_drop_prob=0.2,
44
+ num_channels=None,
45
+ mel_spec_module: nn.Module | None = None,
46
+ mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48
+ vocab_char_map: dict[str, int] | None = None,
49
+ ):
50
+ super().__init__()
51
+
52
+ self.frac_lengths_mask = frac_lengths_mask
53
+
54
+ # mel spec
55
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
+ self.num_channels = num_channels
58
+
59
+ # classifier-free guidance
60
+ self.audio_drop_prob = audio_drop_prob
61
+ self.cond_drop_prob = cond_drop_prob
62
+
63
+ # transformer
64
+ self.transformer = transformer
65
+ dim = transformer.dim
66
+ self.dim = dim
67
+
68
+ # conditional flow related
69
+ self.sigma = sigma
70
+
71
+ # sampling related
72
+ self.odeint_kwargs = odeint_kwargs
73
+
74
+ # vocab map for tokenization
75
+ self.vocab_char_map = vocab_char_map
76
+
77
+ @property
78
+ def device(self):
79
+ return next(self.parameters()).device
80
+
81
+ @torch.no_grad()
82
+ def sample(
83
+ self,
84
+ cond: float["b n d"] | float["b nw"], # noqa: F722
85
+ text: int["b nt"] | list[str], # noqa: F722
86
+ duration: int | int["b"], # noqa: F821
87
+ *,
88
+ lens: int["b"] | None = None, # noqa: F821
89
+ steps=32,
90
+ cfg_strength=1.0,
91
+ sway_sampling_coef=None,
92
+ seed: int | None = None,
93
+ max_duration=4096,
94
+ vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95
+ no_ref_audio=False,
96
+ duplicate_test=False,
97
+ t_inter=0.1,
98
+ edit_mask=None,
99
+ ):
100
+ self.eval()
101
+ # raw wave
102
+
103
+ if cond.ndim == 2:
104
+ cond = self.mel_spec(cond)
105
+ cond = cond.permute(0, 2, 1)
106
+ assert cond.shape[-1] == self.num_channels
107
+
108
+ cond = cond.to(next(self.parameters()).dtype)
109
+
110
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
111
+ if not exists(lens):
112
+ lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
113
+
114
+ # text
115
+
116
+ if isinstance(text, list):
117
+ if exists(self.vocab_char_map):
118
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
119
+ else:
120
+ text = list_str_to_tensor(text).to(device)
121
+ assert text.shape[0] == batch
122
+
123
+ if exists(text):
124
+ text_lens = (text != -1).sum(dim=-1)
125
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
126
+
127
+ # duration
128
+
129
+ cond_mask = lens_to_mask(lens)
130
+ if edit_mask is not None:
131
+ cond_mask = cond_mask & edit_mask
132
+
133
+ if isinstance(duration, int):
134
+ duration = torch.full((batch,), duration, device=device, dtype=torch.long)
135
+
136
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
137
+ duration = duration.clamp(max=max_duration)
138
+ max_duration = duration.amax()
139
+
140
+ # duplicate test corner for inner time step observation
141
+ if duplicate_test:
142
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
143
+
144
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
145
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
146
+ cond_mask = cond_mask.unsqueeze(-1)
147
+ step_cond = torch.where(
148
+ cond_mask, cond, torch.zeros_like(cond)
149
+ ) # allow direct control (cut cond audio) with lens passed in
150
+
151
+ if batch > 1:
152
+ mask = lens_to_mask(duration)
153
+ else: # save memory and speed up, as single inference needs no mask currently
154
+ mask = None
155
+
156
+ # test for no ref audio
157
+ if no_ref_audio:
158
+ cond = torch.zeros_like(cond)
159
+
160
+ # neural ode
161
+
162
+ def fn(t, x):
163
+ # at each step, conditioning is fixed
164
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
165
+
166
+ # predict flow
167
+ pred = self.transformer(
168
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
169
+ )
170
+ if cfg_strength < 1e-5:
171
+ return pred
172
+
173
+ null_pred = self.transformer(
174
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
175
+ )
176
+ return pred + (pred - null_pred) * cfg_strength
177
+
178
+ # noise input
179
+ # to make sure the batch inference result matches across different batch sizes, and likewise for single inference
180
+ # some difference may still remain, likely due to the convolutional layers
181
+ y0 = []
182
+ for dur in duration:
183
+ if exists(seed):
184
+ torch.manual_seed(seed)
185
+ y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
186
+ y0 = pad_sequence(y0, padding_value=0, batch_first=True)
187
+
188
+ t_start = 0
189
+
190
+ # duplicate test corner for inner time step observation
191
+ if duplicate_test:
192
+ t_start = t_inter
193
+ y0 = (1 - t_start) * y0 + t_start * test_cond
194
+ steps = int(steps * (1 - t_start))
195
+
196
+ t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
197
+ if sway_sampling_coef is not None:
198
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
199
+
200
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
201
+
202
+ sampled = trajectory[-1]
203
+ out = sampled
204
+ out = torch.where(cond_mask, cond, out)
205
+
206
+ if exists(vocoder):
207
+ out = out.permute(0, 2, 1)
208
+ out = vocoder(out)
209
+
210
+ return out, trajectory
211
+
212
+ def forward(
213
+ self,
214
+ inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
215
+ text: int["b nt"] | list[str], # noqa: F722
216
+ *,
217
+ lens: int["b"] | None = None, # noqa: F821
218
+ noise_scheduler: str | None = None,
219
+ ):
220
+ # handle raw wave
221
+ if inp.ndim == 2:
222
+ inp = self.mel_spec(inp)
223
+ inp = inp.permute(0, 2, 1)
224
+ assert inp.shape[-1] == self.num_channels
225
+
226
+ batch, seq_len, dtype, device, _Οƒ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
227
+
228
+ # handle text as string
229
+ if isinstance(text, list):
230
+ if exists(self.vocab_char_map):
231
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
232
+ else:
233
+ text = list_str_to_tensor(text).to(device)
234
+ assert text.shape[0] == batch
235
+
236
+ # lens and mask
237
+ if not exists(lens):
238
+ lens = torch.full((batch,), seq_len, device=device)
239
+
240
+ mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
241
+
242
+ # get a random span to mask out for training conditionally
243
+ frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
244
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
245
+
246
+ if exists(mask):
247
+ rand_span_mask &= mask
248
+
249
+ # mel is x1
250
+ x1 = inp
251
+
252
+ # x0 is gaussian noise
253
+ x0 = torch.randn_like(x1)
254
+
255
+ # time step
256
+ time = torch.rand((batch,), dtype=dtype, device=self.device)
257
+ # TODO. noise_scheduler
258
+
259
+ # sample xt (Ο†_t(x) in the paper)
260
+ t = time.unsqueeze(-1).unsqueeze(-1)
261
+ Ο† = (1 - t) * x0 + t * x1
262
+ flow = x1 - x0
263
+
264
+ # only predict what is within the random mask span for infilling
265
+ cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
266
+
267
+ # transformer and cfg training with a drop rate
268
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
269
+ if random() < self.cond_drop_prob: # p_uncond in voicebox paper
270
+ drop_audio_cond = True
271
+ drop_text = True
272
+ else:
273
+ drop_text = False
274
+
275
+ # to rigorously mask out padding, record lengths in collate_fn in dataset.py and pass them in here
276
+ # adding the mask uses more memory, so the batch sampler threshold for long sequences also needs to be scaled down
277
+ pred = self.transformer(
278
+ x=Ο†, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
279
+ )
280
+
281
+ # flow matching loss
282
+ loss = F.mse_loss(pred, flow, reduction="none")
283
+ loss = loss[rand_span_mask]
284
+
285
+ return loss.mean(), cond, pred
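
For orientation, a minimal sketch of how the `sample` method above is typically driven at inference. It assumes the class is exported from model.py as `CFM` and that the `DiT` backbone accepts the usual F5-TTS Base hyperparameters; both the names and the values are assumptions not shown in this diff, and the vocabulary, reference mel and text are purely illustrative.

import torch
from model import CFM, DiT  # `CFM` is the assumed export name of the class above

# character vocab sketch: by convention index 0 is a space and doubles as the unknown id
vocab_char_map = {c: i for i, c in enumerate([" "] + list("абвгдСТзийклмнопрстуфхцчшщъыьэюя,.!?"))}

# hyperparameters mirror the F5-TTS Base DiT config (an assumption, not part of this diff)
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
                  mel_dim=100, text_num_embeds=len(vocab_char_map))
model = CFM(transformer=transformer,
            mel_spec_kwargs=dict(mel_spec_type="bigvgan"),
            vocab_char_map=vocab_char_map)

ref_mel = torch.randn(1, 300, 100)            # [b, n, d] reference mel frames (illustrative)
text = ["ΠΎΠΏΠΎΡ€Π½Ρ‹ΠΉ тСкст ΠΈ Ρ†Π΅Π»Π΅Π²ΠΎΠΉ тСкст"]        # reference transcript + target text

with torch.inference_mode():
    mel_out, _trajectory = model.sample(
        cond=ref_mel, text=text, duration=600,    # total frames, reference included
        steps=32, cfg_strength=2.0, sway_sampling_coef=-1.0,
    )
# mel_out: [1, 600, 100]; pass vocoder=... to sample() to get a waveform back instead.
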
model_modules.py ADDED
@@ -0,0 +1,658 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+ from librosa.filters import mel as librosa_mel_fn
19
+ from torch import nn
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+
26
+ mel_basis_cache = {}
27
+ hann_window_cache = {}
28
+
29
+
30
+ def get_bigvgan_mel_spectrogram(
31
+ waveform,
32
+ n_fft=1024,
33
+ n_mel_channels=100,
34
+ target_sample_rate=24000,
35
+ hop_length=256,
36
+ win_length=1024,
37
+ fmin=0,
38
+ fmax=None,
39
+ center=False,
40
+ ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
41
+ device = waveform.device
42
+ key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
43
+
44
+ if key not in mel_basis_cache:
45
+ mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
46
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
47
+ hann_window_cache[key] = torch.hann_window(win_length).to(device)
48
+
49
+ mel_basis = mel_basis_cache[key]
50
+ hann_window = hann_window_cache[key]
51
+
52
+ padding = (n_fft - hop_length) // 2
53
+ waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
54
+
55
+ spec = torch.stft(
56
+ waveform,
57
+ n_fft,
58
+ hop_length=hop_length,
59
+ win_length=win_length,
60
+ window=hann_window,
61
+ center=center,
62
+ pad_mode="reflect",
63
+ normalized=False,
64
+ onesided=True,
65
+ return_complex=True,
66
+ )
67
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
68
+
69
+ mel_spec = torch.matmul(mel_basis, spec)
70
+ mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
71
+
72
+ return mel_spec
73
+
74
+
75
+ def get_vocos_mel_spectrogram(
76
+ waveform,
77
+ n_fft=1024,
78
+ n_mel_channels=100,
79
+ target_sample_rate=24000,
80
+ hop_length=256,
81
+ win_length=1024,
82
+ ):
83
+ mel_stft = torchaudio.transforms.MelSpectrogram(
84
+ sample_rate=target_sample_rate,
85
+ n_fft=n_fft,
86
+ win_length=win_length,
87
+ hop_length=hop_length,
88
+ n_mels=n_mel_channels,
89
+ power=1,
90
+ center=True,
91
+ normalized=False,
92
+ norm=None,
93
+ ).to(waveform.device)
94
+ if len(waveform.shape) == 3:
95
+ waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
96
+
97
+ assert len(waveform.shape) == 2
98
+
99
+ mel = mel_stft(waveform)
100
+ mel = mel.clamp(min=1e-5).log()
101
+ return mel
102
+
103
+
104
+ class MelSpec(nn.Module):
105
+ def __init__(
106
+ self,
107
+ n_fft=1024,
108
+ hop_length=256,
109
+ win_length=1024,
110
+ n_mel_channels=100,
111
+ target_sample_rate=24_000,
112
+ mel_spec_type="vocos",
113
+ ):
114
+ super().__init__()
115
+ assert mel_spec_type in ["vocos", "bigvgan"], "Only the vocos and bigvgan mel extraction backends are supported"
116
+
117
+ self.n_fft = n_fft
118
+ self.hop_length = hop_length
119
+ self.win_length = win_length
120
+ self.n_mel_channels = n_mel_channels
121
+ self.target_sample_rate = target_sample_rate
122
+
123
+ if mel_spec_type == "vocos":
124
+ self.extractor = get_vocos_mel_spectrogram
125
+ elif mel_spec_type == "bigvgan":
126
+ self.extractor = get_bigvgan_mel_spectrogram
127
+
128
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
129
+
130
+ def forward(self, wav):
131
+ if self.dummy.device != wav.device:
132
+ self.to(wav.device)
133
+
134
+ mel = self.extractor(
135
+ waveform=wav,
136
+ n_fft=self.n_fft,
137
+ n_mel_channels=self.n_mel_channels,
138
+ target_sample_rate=self.target_sample_rate,
139
+ hop_length=self.hop_length,
140
+ win_length=self.win_length,
141
+ )
142
+
143
+ return mel
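
A quick shape check for MelSpec, using the defaults above (24 kHz audio, hop 256, 100 mel bins); the input tensor is random and purely illustrative.

import torch
from model_modules import MelSpec

mel_spec = MelSpec(mel_spec_type="bigvgan")   # or "vocos"
wav = torch.randn(2, 24000)                   # [b, nw]: one second of 24 kHz audio per item
mel = mel_spec(wav)                           # [b, n_mel_channels, frames]
print(mel.shape)                              # torch.Size([2, 100, T]), T β‰ˆ 24000 // 256
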
144
+
145
+
146
+ # sinusoidal position embedding
147
+
148
+
149
+ class SinusPositionEmbedding(nn.Module):
150
+ def __init__(self, dim):
151
+ super().__init__()
152
+ self.dim = dim
153
+
154
+ def forward(self, x, scale=1000):
155
+ device = x.device
156
+ half_dim = self.dim // 2
157
+ emb = math.log(10000) / (half_dim - 1)
158
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
159
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
160
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
161
+ return emb
162
+
163
+
164
+ # convolutional position embedding
165
+
166
+
167
+ class ConvPositionEmbedding(nn.Module):
168
+ def __init__(self, dim, kernel_size=31, groups=16):
169
+ super().__init__()
170
+ assert kernel_size % 2 != 0
171
+ self.conv1d = nn.Sequential(
172
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
173
+ nn.Mish(),
174
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
175
+ nn.Mish(),
176
+ )
177
+
178
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
179
+ if mask is not None:
180
+ mask = mask[..., None]
181
+ x = x.masked_fill(~mask, 0.0)
182
+
183
+ x = x.permute(0, 2, 1)
184
+ x = self.conv1d(x)
185
+ out = x.permute(0, 2, 1)
186
+
187
+ if mask is not None:
188
+ out = out.masked_fill(~mask, 0.0)
189
+
190
+ return out
191
+
192
+
193
+ # rotary positional embedding related
194
+
195
+
196
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
197
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
198
+ # has some connection to NTK literature
199
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
200
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
201
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
202
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
203
+ t = torch.arange(end, device=freqs.device) # type: ignore
204
+ freqs = torch.outer(t, freqs).float() # type: ignore
205
+ freqs_cos = torch.cos(freqs) # real part
206
+ freqs_sin = torch.sin(freqs) # imaginary part
207
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
208
+
209
+
210
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
211
+ # length = length if isinstance(length, int) else length.max()
212
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
213
+ pos = (
214
+ start.unsqueeze(1)
215
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
216
+ )
217
+ # clamp to avoid indexing beyond max_pos for very long sequences
218
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
219
+ return pos
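
A small sketch of how these two helpers are usually combined into the rotary table that the attention layers consume; the head dimension and lengths are illustrative, and the `(freqs, xpos_scale)` pairing follows the unpacking inside AttnProcessor below.

import torch
from model_modules import precompute_freqs_cis, get_pos_embed_indices

head_dim, max_pos = 64, 4096
freqs_cis = precompute_freqs_cis(head_dim, max_pos)            # [max_pos, head_dim]

batch, seq_len = 2, 300
start = torch.zeros(batch, dtype=torch.long)                   # per-sample start offsets
pos = get_pos_embed_indices(start, seq_len, max_pos=max_pos)   # [batch, seq_len]
freqs = freqs_cis[pos]                                         # [batch, seq_len, head_dim]
rope = (freqs, None)   # the (freqs, xpos_scale) tuple the attention processors expect
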
220
+
221
+
222
+ # Global Response Normalization (GRN) layer, as introduced in ConvNeXt-V2
223
+
224
+
225
+ class GRN(nn.Module):
226
+ def __init__(self, dim):
227
+ super().__init__()
228
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
229
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
230
+
231
+ def forward(self, x):
232
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
233
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
234
+ return self.gamma * (x * Nx) + self.beta + x
235
+
236
+
237
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
238
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
239
+
240
+
241
+ class ConvNeXtV2Block(nn.Module):
242
+ def __init__(
243
+ self,
244
+ dim: int,
245
+ intermediate_dim: int,
246
+ dilation: int = 1,
247
+ ):
248
+ super().__init__()
249
+ padding = (dilation * (7 - 1)) // 2
250
+ self.dwconv = nn.Conv1d(
251
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
252
+ ) # depthwise conv
253
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
254
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
255
+ self.act = nn.GELU()
256
+ self.grn = GRN(intermediate_dim)
257
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
258
+
259
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
260
+ residual = x
261
+ x = x.transpose(1, 2) # b n d -> b d n
262
+ x = self.dwconv(x)
263
+ x = x.transpose(1, 2) # b d n -> b n d
264
+ x = self.norm(x)
265
+ x = self.pwconv1(x)
266
+ x = self.act(x)
267
+ x = self.grn(x)
268
+ x = self.pwconv2(x)
269
+ return residual + x
270
+
271
+
272
+ # AdaLayerNormZero
273
+ # return the modulated x for attn input, plus params for the later mlp modulation
274
+
275
+
276
+ class AdaLayerNormZero(nn.Module):
277
+ def __init__(self, dim):
278
+ super().__init__()
279
+
280
+ self.silu = nn.SiLU()
281
+ self.linear = nn.Linear(dim, dim * 6)
282
+
283
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
284
+
285
+ def forward(self, x, emb=None):
286
+ emb = self.linear(self.silu(emb))
287
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
288
+
289
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
290
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
291
+
292
+
293
+ # AdaLayerNormZero for final layer
294
+ # return only the modulated x for attn input, since there is no further mlp modulation
295
+
296
+
297
+ class AdaLayerNormZero_Final(nn.Module):
298
+ def __init__(self, dim):
299
+ super().__init__()
300
+
301
+ self.silu = nn.SiLU()
302
+ self.linear = nn.Linear(dim, dim * 2)
303
+
304
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
305
+
306
+ def forward(self, x, emb):
307
+ emb = self.linear(self.silu(emb))
308
+ scale, shift = torch.chunk(emb, 2, dim=1)
309
+
310
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
311
+ return x
312
+
313
+
314
+ # FeedForward
315
+
316
+
317
+ class FeedForward(nn.Module):
318
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
319
+ super().__init__()
320
+ inner_dim = int(dim * mult)
321
+ dim_out = dim_out if dim_out is not None else dim
322
+
323
+ activation = nn.GELU(approximate=approximate)
324
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
325
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
326
+
327
+ def forward(self, x):
328
+ return self.ff(x)
329
+
330
+
331
+ # Attention with possible joint part
332
+ # modified from diffusers/src/diffusers/models/attention_processor.py
333
+
334
+
335
+ class Attention(nn.Module):
336
+ def __init__(
337
+ self,
338
+ processor: JointAttnProcessor | AttnProcessor,
339
+ dim: int,
340
+ heads: int = 8,
341
+ dim_head: int = 64,
342
+ dropout: float = 0.0,
343
+ context_dim: Optional[int] = None, # if not None -> joint attention
344
+ context_pre_only=None,
345
+ ):
346
+ super().__init__()
347
+
348
+ if not hasattr(F, "scaled_dot_product_attention"):
349
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
350
+
351
+ self.processor = processor
352
+
353
+ self.dim = dim
354
+ self.heads = heads
355
+ self.inner_dim = dim_head * heads
356
+ self.dropout = dropout
357
+
358
+ self.context_dim = context_dim
359
+ self.context_pre_only = context_pre_only
360
+
361
+ self.to_q = nn.Linear(dim, self.inner_dim)
362
+ self.to_k = nn.Linear(dim, self.inner_dim)
363
+ self.to_v = nn.Linear(dim, self.inner_dim)
364
+
365
+ if self.context_dim is not None:
366
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
367
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
368
+ if self.context_pre_only is not None:
369
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
370
+
371
+ self.to_out = nn.ModuleList([])
372
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
373
+ self.to_out.append(nn.Dropout(dropout))
374
+
375
+ if self.context_pre_only is not None and not self.context_pre_only:
376
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
377
+
378
+ def forward(
379
+ self,
380
+ x: float["b n d"], # noised input x # noqa: F722
381
+ c: float["b n d"] = None, # context c # noqa: F722
382
+ mask: bool["b n"] | None = None, # noqa: F722
383
+ rope=None, # rotary position embedding for x
384
+ c_rope=None, # rotary position embedding for c
385
+ ) -> torch.Tensor:
386
+ if c is not None:
387
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
388
+ else:
389
+ return self.processor(self, x, mask=mask, rope=rope)
390
+
391
+
392
+ # Attention processor
393
+
394
+
395
+ class AttnProcessor:
396
+ def __init__(self):
397
+ pass
398
+
399
+ def __call__(
400
+ self,
401
+ attn: Attention,
402
+ x: float["b n d"], # noised input x # noqa: F722
403
+ mask: bool["b n"] | None = None, # noqa: F722
404
+ rope=None, # rotary position embedding
405
+ ) -> torch.FloatTensor:
406
+ batch_size = x.shape[0]
407
+
408
+ # `sample` projections.
409
+ query = attn.to_q(x)
410
+ key = attn.to_k(x)
411
+ value = attn.to_v(x)
412
+
413
+ # apply rotary position embedding
414
+ if rope is not None:
415
+ freqs, xpos_scale = rope
416
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
417
+
418
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
419
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
420
+
421
+ # attention
422
+ inner_dim = key.shape[-1]
423
+ head_dim = inner_dim // attn.heads
424
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
425
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
+
428
+ # mask. e.g. if inference gets a batch with different target durations, mask out the padding
429
+ if mask is not None:
430
+ attn_mask = mask
431
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
432
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
433
+ else:
434
+ attn_mask = None
435
+
436
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
437
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
438
+ x = x.to(query.dtype)
439
+
440
+ # linear proj
441
+ x = attn.to_out[0](x)
442
+ # dropout
443
+ x = attn.to_out[1](x)
444
+
445
+ if mask is not None:
446
+ mask = mask.unsqueeze(-1)
447
+ x = x.masked_fill(~mask, 0.0)
448
+
449
+ return x
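
A minimal smoke test for the Attention / AttnProcessor pair above (self-attention only, no mask or rotary embedding; dimensions are illustrative, and PyTorch 2.0 is needed for scaled_dot_product_attention).

import torch
from model_modules import Attention, AttnProcessor

attn = Attention(processor=AttnProcessor(), dim=512, heads=8, dim_head=64)
x = torch.randn(2, 300, 512)   # [b, n, d]
out = attn(x)                  # output keeps the input shape: [2, 300, 512]
assert out.shape == x.shape
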
450
+
451
+
452
+ # Joint Attention processor for MM-DiT
453
+ # modified from diffusers/src/diffusers/models/attention_processor.py
454
+
455
+
456
+ class JointAttnProcessor:
457
+ def __init__(self):
458
+ pass
459
+
460
+ def __call__(
461
+ self,
462
+ attn: Attention,
463
+ x: float["b n d"], # noised input x # noqa: F722
464
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
465
+ mask: bool["b n"] | None = None, # noqa: F722
466
+ rope=None, # rotary position embedding for x
467
+ c_rope=None, # rotary position embedding for c
468
+ ) -> torch.FloatTensor:
469
+ residual = x
470
+
471
+ batch_size = c.shape[0]
472
+
473
+ # `sample` projections.
474
+ query = attn.to_q(x)
475
+ key = attn.to_k(x)
476
+ value = attn.to_v(x)
477
+
478
+ # `context` projections.
479
+ c_query = attn.to_q_c(c)
480
+ c_key = attn.to_k_c(c)
481
+ c_value = attn.to_v_c(c)
482
+
483
+ # apply rope for context and noised input independently
484
+ if rope is not None:
485
+ freqs, xpos_scale = rope
486
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
487
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
488
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
489
+ if c_rope is not None:
490
+ freqs, xpos_scale = c_rope
491
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
492
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
493
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
494
+
495
+ # attention
496
+ query = torch.cat([query, c_query], dim=1)
497
+ key = torch.cat([key, c_key], dim=1)
498
+ value = torch.cat([value, c_value], dim=1)
499
+
500
+ inner_dim = key.shape[-1]
501
+ head_dim = inner_dim // attn.heads
502
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
503
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
504
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
505
+
506
+ # mask. e.g. if inference gets a batch with different target durations, mask out the padding
507
+ if mask is not None:
508
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
509
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
510
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
511
+ else:
512
+ attn_mask = None
513
+
514
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
515
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
516
+ x = x.to(query.dtype)
517
+
518
+ # Split the attention outputs.
519
+ x, c = (
520
+ x[:, : residual.shape[1]],
521
+ x[:, residual.shape[1] :],
522
+ )
523
+
524
+ # linear proj
525
+ x = attn.to_out[0](x)
526
+ # dropout
527
+ x = attn.to_out[1](x)
528
+ if not attn.context_pre_only:
529
+ c = attn.to_out_c(c)
530
+
531
+ if mask is not None:
532
+ mask = mask.unsqueeze(-1)
533
+ x = x.masked_fill(~mask, 0.0)
534
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
535
+
536
+ return x, c
537
+
538
+
539
+ # DiT Block
540
+
541
+
542
+ class DiTBlock(nn.Module):
543
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
544
+ super().__init__()
545
+
546
+ self.attn_norm = AdaLayerNormZero(dim)
547
+ self.attn = Attention(
548
+ processor=AttnProcessor(),
549
+ dim=dim,
550
+ heads=heads,
551
+ dim_head=dim_head,
552
+ dropout=dropout,
553
+ )
554
+
555
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
556
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
557
+
558
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
559
+ # pre-norm & modulation for attention input
560
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
561
+
562
+ # attention
563
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
564
+
565
+ # process attention output for input x
566
+ x = x + gate_msa.unsqueeze(1) * attn_output
567
+
568
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
569
+ ff_output = self.ff(norm)
570
+ x = x + gate_mlp.unsqueeze(1) * ff_output
571
+
572
+ return x
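
A sketch wiring a DiTBlock to the TimestepEmbedding defined further down in this file; the hidden size and sequence length are illustrative.

import torch
from model_modules import DiTBlock, TimestepEmbedding

dim = 512
block = DiTBlock(dim=dim, heads=8, dim_head=64)
time_embed = TimestepEmbedding(dim)

x = torch.randn(2, 300, dim)       # [b, n, d] hidden states
t = time_embed(torch.rand(2))      # [b, d] conditioning derived from the flow time step
y = block(x, t)                    # [2, 300, 512]
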
573
+
574
+
575
+ # MMDiT Block https://arxiv.org/abs/2403.03206
576
+
577
+
578
+ class MMDiTBlock(nn.Module):
579
+ r"""
580
+ modified from diffusers/src/diffusers/models/attention.py
581
+
582
+ notes.
583
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
584
+ _x: noised input related. (right part)
585
+ context_pre_only: the last layer only does pre-norm + modulation, since there is no further ffn
586
+ """
587
+
588
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
589
+ super().__init__()
590
+
591
+ self.context_pre_only = context_pre_only
592
+
593
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
594
+ self.attn_norm_x = AdaLayerNormZero(dim)
595
+ self.attn = Attention(
596
+ processor=JointAttnProcessor(),
597
+ dim=dim,
598
+ heads=heads,
599
+ dim_head=dim_head,
600
+ dropout=dropout,
601
+ context_dim=dim,
602
+ context_pre_only=context_pre_only,
603
+ )
604
+
605
+ if not context_pre_only:
606
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
607
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
608
+ else:
609
+ self.ff_norm_c = None
610
+ self.ff_c = None
611
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
612
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
613
+
614
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
615
+ # pre-norm & modulation for attention input
616
+ if self.context_pre_only:
617
+ norm_c = self.attn_norm_c(c, t)
618
+ else:
619
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
620
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
621
+
622
+ # attention
623
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
624
+
625
+ # process attention output for context c
626
+ if self.context_pre_only:
627
+ c = None
628
+ else: # if not last layer
629
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
630
+
631
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
632
+ c_ff_output = self.ff_c(norm_c)
633
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
634
+
635
+ # process attention output for input x
636
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
637
+
638
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
639
+ x_ff_output = self.ff_x(norm_x)
640
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
641
+
642
+ return c, x
643
+
644
+
645
+ # time step conditioning embedding
646
+
647
+
648
+ class TimestepEmbedding(nn.Module):
649
+ def __init__(self, dim, freq_embed_dim=256):
650
+ super().__init__()
651
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
652
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
653
+
654
+ def forward(self, timestep: float["b"]): # noqa: F821
655
+ time_hidden = self.time_embed(timestep)
656
+ time_hidden = time_hidden.to(timestep.dtype)
657
+ time = self.time_mlp(time_hidden) # b d
658
+ return time
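
And the corresponding sketch for the MMDiT path, where text context c and noised input x attend jointly (context_pre_only=False as in every non-final layer; dimensions are illustrative).

import torch
from model_modules import MMDiTBlock, TimestepEmbedding

dim = 512
block = MMDiTBlock(dim=dim, heads=8, dim_head=64, context_pre_only=False)
time_embed = TimestepEmbedding(dim)

x = torch.randn(2, 300, dim)    # noised mel hidden states
c = torch.randn(2, 50, dim)     # text-context hidden states
t = time_embed(torch.rand(2))

c_out, x_out = block(x, c, t)   # note the return order (c, x): [2, 50, 512], [2, 300, 512]
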
model_utils.py ADDED
@@ -0,0 +1,187 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ from collections import defaultdict
6
+ from importlib.resources import files
7
+
8
+ import torch
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ import jieba
12
+ from pypinyin import lazy_pinyin, Style
13
+
14
+
15
+ # seed everything
16
+
17
+
18
+ def seed_everything(seed=0):
19
+ random.seed(seed)
20
+ os.environ["PYTHONHASHSEED"] = str(seed)
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed(seed)
23
+ torch.cuda.manual_seed_all(seed)
24
+ torch.backends.cudnn.deterministic = True
25
+ torch.backends.cudnn.benchmark = False
26
+
27
+
28
+ # helpers
29
+
30
+
31
+ def exists(v):
32
+ return v is not None
33
+
34
+
35
+ def default(v, d):
36
+ return v if exists(v) else d
37
+
38
+
39
+ # tensor helpers
40
+
41
+
42
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
43
+ if not exists(length):
44
+ length = t.amax()
45
+
46
+ seq = torch.arange(length, device=t.device)
47
+ return seq[None, :] < t[:, None]
48
+
49
+
50
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
51
+ max_seq_len = seq_len.max().item()
52
+ seq = torch.arange(max_seq_len, device=start.device).long()
53
+ start_mask = seq[None, :] >= start[:, None]
54
+ end_mask = seq[None, :] < end[:, None]
55
+ return start_mask & end_mask
56
+
57
+
58
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
59
+ lengths = (frac_lengths * seq_len).long()
60
+ max_start = seq_len - lengths
61
+
62
+ rand = torch.rand_like(frac_lengths)
63
+ start = (max_start * rand).long().clamp(min=0)
64
+ end = start + lengths
65
+
66
+ return mask_from_start_end_indices(seq_len, start, end)
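
A quick illustration of the two mask helpers above: lens_to_mask builds padding masks from lengths, and mask_from_frac_lengths carves out one random contiguous span per sample for the infilling objective.

import torch
from model_utils import lens_to_mask, mask_from_frac_lengths

lens = torch.tensor([5, 3])
mask = lens_to_mask(lens)                    # [2, 5] bool, True on valid positions
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])

frac = torch.tensor([0.5, 0.5])
span = mask_from_frac_lengths(lens, frac)    # [2, 5] bool, one random span per row
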
67
+
68
+
69
+ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
70
+ if not exists(mask):
71
+ return t.mean(dim=1)
72
+
73
+ t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
74
+ num = t.sum(dim=1)
75
+ den = mask.float().sum(dim=1)
76
+
77
+ return num / den.clamp(min=1.0)
78
+
79
+
80
+ # simple utf-8 tokenizer, since paper went character based
81
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
82
+ list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
83
+ text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
84
+ return text
85
+
86
+
87
+ # char tokenizer, based on custom dataset's extracted .txt file
88
+ def list_str_to_idx(
89
+ text: list[str] | list[list[str]],
90
+ vocab_char_map: dict[str, int], # {char: idx}
91
+ padding_value=-1,
92
+ ) -> int["b nt"]: # noqa: F722
93
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
94
+ text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
95
+ return text
96
+
97
+
98
+ # Get tokenizer
99
+
100
+
101
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
102
+ """
103
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
104
+ - "char" for char-wise tokenizer, need .txt vocab_file
105
+ - "byte" for utf-8 tokenizer
106
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
107
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
108
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
109
+ - if use "byte", set to 256 (unicode byte range)
110
+ """
111
+ if tokenizer in ["pinyin", "char"]:
112
+ tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
113
+ with open(tokenizer_path, "r", encoding="utf-8") as f:
114
+ vocab_char_map = {}
115
+ for i, char in enumerate(f):
116
+ vocab_char_map[char[:-1]] = i
117
+ vocab_size = len(vocab_char_map)
118
+ assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
119
+
120
+ elif tokenizer == "byte":
121
+ vocab_char_map = None
122
+ vocab_size = 256
123
+
124
+ elif tokenizer == "custom":
125
+ with open(dataset_name, "r", encoding="utf-8") as f:
126
+ vocab_char_map = {}
127
+ for i, char in enumerate(f):
128
+ vocab_char_map[char[:-1]] = i
129
+ vocab_size = len(vocab_char_map)
130
+
131
+ return vocab_char_map, vocab_size
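
A sketch of the "custom" tokenizer mode, which takes a path to a vocab.txt file directly as dataset_name; the path below is illustrative.

from model_utils import get_tokenizer

vocab_char_map, vocab_size = get_tokenizer("vocab.txt", tokenizer="custom")
print(vocab_size)   # number of lines (characters) in vocab.txt
# by convention the first line is a space, so index 0 also serves as the unknown-char id
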
132
+
133
+
134
+ # convert char to pinyin
135
+
136
+
137
+ def convert_char_to_pinyin(text_list, polyphone=True):
138
+ final_text_list = []
139
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans(
140
+ {"β€œ": '"', "”": '"', "β€˜": "'", "’": "'"}
141
+ ) # in case librispeech (orig no-pc) test-clean
142
+ custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
143
+ for text in text_list:
144
+ char_list = []
145
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
146
+ text = text.translate(custom_trans)
147
+ for seg in jieba.cut(text):
148
+ seg_byte_len = len(bytes(seg, "UTF-8"))
149
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
150
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
151
+ char_list.append(" ")
152
+ char_list.extend(seg)
153
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
154
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
155
+ for c in seg:
156
+ if c not in "γ€‚οΌŒγ€οΌ›οΌšοΌŸοΌγ€Šγ€‹γ€γ€‘β€”β€¦":
157
+ char_list.append(" ")
158
+ char_list.append(c)
159
+ else: # if mixed chinese characters, alphabets and symbols
160
+ for c in seg:
161
+ if ord(c) < 256:
162
+ char_list.extend(c)
163
+ elif '\u0400' <= c <= '\u04FF': # Cyrillic Unicode block
164
+ char_list.extend(c)
165
+ else:
166
+ if c not in "γ€‚οΌŒγ€οΌ›οΌšοΌŸοΌγ€Šγ€‹γ€γ€‘β€”β€¦":
167
+ char_list.append(" ")
168
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
169
+ else: # if is zh punc
170
+ char_list.append(c)
171
+ final_text_list.append(char_list)
172
+
173
+ return final_text_list
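
A brief illustration of convert_char_to_pinyin on mixed input: Latin and Cyrillic characters pass through unchanged (the Cyrillic branch above is what keeps Russian text intact), while Chinese characters become tone-marked pinyin. Exact segmentation depends on jieba.

from model_utils import convert_char_to_pinyin

batch = ["ΠŸΡ€ΠΈΠ²Π΅Ρ‚, ΠΌΠΈΡ€!", "hello δ½ ε₯½"]
char_lists = convert_char_to_pinyin(batch)
# char_lists[0]: the Cyrillic string split into its characters, kept as-is
# char_lists[1]: the Latin characters plus pinyin tokens such as "ni3" and "hao3"
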
174
+
175
+
176
+ # filter func for dirty data with many repetitions
177
+
178
+
179
+ def repetition_found(text, length=2, tolerance=10):
180
+ pattern_count = defaultdict(int)
181
+ for i in range(len(text) - length + 1):
182
+ pattern = text[i : i + length]
183
+ pattern_count[pattern] += 1
184
+ for pattern, count in pattern_count.items():
185
+ if count > tolerance:
186
+ return True
187
+ return False
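
And a quick check of the repetition filter: it flags text in which any substring of the given length repeats more than `tolerance` times.

from model_utils import repetition_found

print(repetition_found("ha" * 20, length=2, tolerance=10))   # True ("ha" appears 20 times)
print(repetition_found("a normal sentence", length=2))       # False
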
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ accelerate>=0.33.0
2
+ bitsandbytes>0.37.0
3
+ cached_path
4
+ click
5
+ datasets
6
+ ema_pytorch>=0.5.2
7
+ gradio>=3.45.2
8
+ hydra-core>=1.3.0
9
+ jieba
10
+ librosa
11
+ matplotlib
12
+ numpy<=1.26.4
13
+ pydub
14
+ pypinyin
15
+ safetensors
16
+ soundfile
17
+ tomli
18
+ torch>=2.0.0
19
+ torchaudio>=2.0.0
20
+ torchdiffeq
21
+ tqdm>=4.65.0
22
+ transformers
23
+ transformers_stream_generator
24
+ vocos
25
+ wandb
26
+ x_transformers>=1.31.14