PoTaTo721 committed on
Commit
b2eb230
·
1 Parent(s): b3355c2

Update to V1.5

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +286 -340
  2. fish_speech/callbacks/__init__.py +3 -3
  3. fish_speech/callbacks/grad_norm.py +113 -113
  4. fish_speech/configs/base.yaml +87 -87
  5. fish_speech/configs/firefly_gan_vq.yaml +33 -33
  6. fish_speech/configs/lora/r_8_alpha_16.yaml +4 -4
  7. fish_speech/configs/model/dual_ar_2_codebook_large.yaml +0 -9
  8. fish_speech/configs/model/dual_ar_2_codebook_medium.yaml +0 -9
  9. fish_speech/configs/model/dual_ar_2_codebook_small.yaml +0 -13
  10. fish_speech/configs/model/naive_2_codebook_small.yaml +0 -12
  11. fish_speech/configs/text2semantic_finetune.yaml +83 -83
  12. fish_speech/configs/text2semantic_finetune_lora.yaml +0 -13
  13. fish_speech/configs/text2semantic_pretrain.yaml +0 -74
  14. fish_speech/configs/text2semantic_sft.yaml +0 -87
  15. fish_speech/configs/vqgan_finetune.yaml +0 -135
  16. fish_speech/configs/vqgan_pretrain.yaml +0 -139
  17. fish_speech/conversation.py +267 -2
  18. fish_speech/datasets/concat_repeat.py +53 -53
  19. fish_speech/datasets/protos/text-data.proto +24 -24
  20. fish_speech/datasets/protos/text_data_pb2.py +33 -33
  21. fish_speech/datasets/protos/text_data_stream.py +36 -36
  22. fish_speech/datasets/semantic.py +496 -496
  23. fish_speech/datasets/text.py +0 -661
  24. fish_speech/datasets/vqgan.py +147 -147
  25. fish_speech/i18n/README.md +27 -27
  26. fish_speech/i18n/__init__.py +3 -3
  27. fish_speech/i18n/core.py +40 -40
  28. fish_speech/i18n/locale/en_US.json +123 -122
  29. fish_speech/i18n/locale/es_ES.json +123 -122
  30. fish_speech/i18n/locale/ja_JP.json +123 -123
  31. fish_speech/i18n/locale/ko_KR.json +123 -0
  32. fish_speech/i18n/locale/pt_BR.json +133 -133
  33. fish_speech/i18n/locale/zh_CN.json +123 -122
  34. fish_speech/i18n/scan.py +122 -122
  35. fish_speech/models/text2semantic/lit_module.py +202 -202
  36. fish_speech/models/text2semantic/llama.py +887 -779
  37. fish_speech/models/text2semantic/lora.py +92 -92
  38. fish_speech/models/vqgan/lit_module.py +0 -442
  39. fish_speech/models/vqgan/modules/discriminator.py +0 -44
  40. fish_speech/models/vqgan/modules/firefly.py +596 -596
  41. fish_speech/models/vqgan/modules/fsq.py +116 -116
  42. fish_speech/models/vqgan/modules/reference.py +0 -113
  43. fish_speech/models/vqgan/modules/wavenet.py +0 -225
  44. fish_speech/models/vqgan/spectrogram.py +0 -122
  45. fish_speech/models/vqgan/utils.py +94 -94
  46. fish_speech/scheduler.py +40 -40
  47. fish_speech/text/__init__.py +4 -4
  48. fish_speech/text/chn_text_norm/.gitignore +114 -114
  49. fish_speech/text/chn_text_norm/README.md +36 -36
  50. fish_speech/text/chn_text_norm/basic_class.py +172 -172
app.py CHANGED
@@ -10,7 +10,7 @@ import gc
10
 
11
  # Download if not exists
12
  os.makedirs("checkpoints", exist_ok=True)
13
- snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
14
 
15
  print("All checkpoints downloaded")
16
 
@@ -31,11 +31,11 @@ torchaudio.set_audio_backend("soundfile")
31
  from loguru import logger
32
  from transformers import AutoTokenizer
33
 
34
- from tools.llama.generate import launch_thread_safe_queue
35
- from tools.vqgan.inference import load_model as load_vqgan_model
36
  from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 
37
  from tools.api import decode_vq_tokens, encode_reference
38
- from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
39
  from tools.llama.generate import (
40
  GenerateRequest,
41
  GenerateResponse,
@@ -44,20 +44,43 @@ from tools.llama.generate import (
44
  )
45
  from tools.vqgan.inference import load_model as load_decoder_model
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Make einx happy
48
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
49
 
50
 
51
  HEADER_MD = """# Fish Speech
52
 
53
- ## The demo in this space is version 1.4, Please check [Fish Audio](https://fish.audio) for the best model.
54
- ## 该 Demo 为 Fish Speech 1.4 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
55
 
56
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
57
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
58
 
59
- You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).
60
- 你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.4) 找到模型.
61
 
62
  Related code and weights are released under CC BY-NC-SA 4.0 License.
63
  相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
@@ -65,8 +88,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
65
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
66
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
67
 
68
- The model running in this WebUI is Fish Speech V1.4 Medium.
69
- 在此 WebUI 中运行的模型是 Fish Speech V1.4 Medium.
70
  """
71
 
72
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -95,48 +118,77 @@ def build_html_error_message(error):
95
 
96
  @GPU_DECORATOR
97
  @torch.inference_mode()
98
- def inference(
99
- text,
100
- enable_reference_audio,
101
- reference_audio,
102
- reference_text,
103
- max_new_tokens,
104
- chunk_length,
105
- top_p,
106
- repetition_penalty,
107
- temperature,
108
- streaming=False
109
- ):
110
- if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
111
- return (
112
- None,
113
- None,
114
- "Text is too long, please keep it under {} characters.".format(
115
- args.max_gradio_length
116
- ),
117
  )
118
 
119
- # Parse reference audio aka prompt
120
- prompt_tokens = encode_reference(
121
- decoder_model=decoder_model,
122
- reference_audio=reference_audio,
123
- enable_reference_audio=enable_reference_audio,
124
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  # LLAMA Inference
127
  request = dict(
128
  device=decoder_model.device,
129
- max_new_tokens=max_new_tokens,
130
- text=text,
131
- top_p=top_p,
132
- repetition_penalty=repetition_penalty,
133
- temperature=temperature,
 
 
 
 
134
  compile=args.compile,
135
- iterative_prompt=chunk_length > 0,
136
- chunk_length=chunk_length,
137
- max_length=2048,
138
- prompt_tokens=prompt_tokens if enable_reference_audio else None,
139
- prompt_text=reference_text if enable_reference_audio else None,
140
  )
141
 
142
  response_queue = queue.Queue()
@@ -152,19 +204,15 @@ def inference(
152
  while True:
153
  result: WrappedGenerateResponse = response_queue.get()
154
  if result.status == "error":
155
- return None, None, build_html_error_message(result.response)
 
156
 
157
  result: GenerateResponse = result.response
158
  if result.action == "next":
159
  break
160
 
161
- with torch.autocast(
162
- device_type=(
163
- "cpu"
164
- if decoder_model.device.type == "mps"
165
- else decoder_model.device.type
166
- ),
167
- dtype=args.precision,
168
  ):
169
  fake_audios = decode_vq_tokens(
170
  decoder_model=decoder_model,
@@ -179,79 +227,24 @@ def inference(
179
  None,
180
  None,
181
  build_html_error_message(
182
- "No audio generated, please check the input text."
183
  ),
184
  )
185
 
186
- # Return the final audio
187
  audio = np.concatenate(segments, axis=0)
188
- return None, (decoder_model.spec_transform.sample_rate, audio), None
189
 
190
  if torch.cuda.is_available():
191
  torch.cuda.empty_cache()
192
  gc.collect()
193
 
194
-
195
- def inference_with_auto_rerank(
196
- text,
197
- enable_reference_audio,
198
- reference_audio,
199
- reference_text,
200
- max_new_tokens,
201
- chunk_length,
202
- top_p,
203
- repetition_penalty,
204
- temperature,
205
- use_auto_rerank,
206
- streaming=False,
207
- ):
208
- max_attempts = 2 if use_auto_rerank else 1
209
- best_wer = float("inf")
210
- best_audio = None
211
- best_sample_rate = None
212
-
213
- for attempt in range(max_attempts):
214
- _, (sample_rate, audio), message = inference(
215
- text,
216
- enable_reference_audio,
217
- reference_audio,
218
- reference_text,
219
- max_new_tokens,
220
- chunk_length,
221
- top_p,
222
- repetition_penalty,
223
- temperature,
224
- streaming=False,
225
- )
226
-
227
- if audio is None:
228
- return None, None, message
229
-
230
- if not use_auto_rerank:
231
- return None, (sample_rate, audio), None
232
-
233
- asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
234
- wer = calculate_wer(text, asr_result["text"])
235
-
236
- if wer <= 0.3 and not asr_result["huge_gap"]:
237
- return None, (sample_rate, audio), None
238
-
239
- if wer < best_wer:
240
- best_wer = wer
241
- best_audio = audio
242
- best_sample_rate = sample_rate
243
-
244
- if attempt == max_attempts - 1:
245
- break
246
-
247
- return None, (best_sample_rate, best_audio), None
248
-
249
-
250
  n_audios = 4
251
 
252
  global_audio_list = []
253
  global_error_list = []
254
 
 
255
  def inference_wrapper(
256
  text,
257
  enable_reference_audio,
@@ -262,14 +255,14 @@ def inference_wrapper(
262
  top_p,
263
  repetition_penalty,
264
  temperature,
 
265
  batch_infer_num,
266
- if_load_asr_model,
267
  ):
268
  audios = []
269
  errors = []
270
 
271
  for _ in range(batch_infer_num):
272
- result = inference_with_auto_rerank(
273
  text,
274
  enable_reference_audio,
275
  reference_audio,
@@ -279,10 +272,10 @@ def inference_wrapper(
279
  top_p,
280
  repetition_penalty,
281
  temperature,
282
- if_load_asr_model,
283
  )
284
 
285
- _, audio_data, error_message = result
286
 
287
  audios.append(
288
  gr.Audio(value=audio_data if audio_data else None, visible=True),
@@ -314,52 +307,17 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
314
  buffer.close()
315
  return wav_header_bytes
316
 
317
-
318
  def normalize_text(user_input, use_normalization):
319
  if use_normalization:
320
  return ChnNormedText(raw_text=user_input).normalize()
321
  else:
322
  return user_input
323
 
324
-
325
- asr_model = None
326
-
327
-
328
- def change_if_load_asr_model(if_load):
329
- global asr_model
330
-
331
- if if_load:
332
- gr.Warning("Loading faster whisper model...")
333
- if asr_model is None:
334
- asr_model = load_model()
335
- return gr.Checkbox(label="Unload faster whisper model", value=if_load)
336
-
337
- if if_load is False:
338
- gr.Warning("Unloading faster whisper model...")
339
- del asr_model
340
- asr_model = None
341
- if torch.cuda.is_available():
342
- torch.cuda.empty_cache()
343
- gc.collect()
344
- return gr.Checkbox(label="Load faster whisper model", value=if_load)
345
-
346
-
347
- def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
348
- if if_load and asr_model is not None:
349
- if (
350
- if_auto_label
351
- and enable_ref
352
- and ref_audio is not None
353
- and ref_text.strip() == ""
354
- ):
355
- data, sample_rate = librosa.load(ref_audio)
356
- res = batch_asr(asr_model, [data], sample_rate)[0]
357
- ref_text = res["text"]
358
- else:
359
- gr.Warning("Whisper model not loaded!")
360
-
361
- return gr.Textbox(value=ref_text)
362
-
363
 
364
  def build_app():
365
  with gr.Blocks(theme=gr.themes.Base()) as app:
@@ -377,202 +335,185 @@ def build_app():
377
  with gr.Row():
378
  with gr.Column(scale=3):
379
  text = gr.Textbox(
380
- label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10
381
  )
382
  refined_text = gr.Textbox(
383
- label="Realtime Transform Text",
384
- placeholder=
385
- "Normalization Result Preview (Currently Only Chinese)",
 
386
  lines=5,
387
  interactive=False,
388
  )
389
 
390
  with gr.Row():
391
- if_refine_text = gr.Checkbox(
392
- label="Text Normalization (ZH)",
393
- value=False,
394
- scale=1,
395
- )
396
-
397
- if_load_asr_model = gr.Checkbox(
398
- label="Load / Unload ASR model for auto-reranking",
399
  value=False,
400
- scale=3,
401
  )
402
 
403
  with gr.Row():
404
- with gr.Tab(label="Advanced Config"):
405
- chunk_length = gr.Slider(
406
- label="Iterative Prompt Length, 0 means off",
407
- minimum=0,
408
- maximum=500,
409
- value=200,
410
- step=8,
411
- )
412
-
413
- max_new_tokens = gr.Slider(
414
- label="Maximum tokens per batch, 0 means no limit",
415
- minimum=0,
416
- maximum=2048,
417
- value=0, # 0 means no limit
418
- step=8,
419
- )
420
-
421
- top_p = gr.Slider(
422
- label="Top-P",
423
- minimum=0.6,
424
- maximum=0.9,
425
- value=0.7,
426
- step=0.01,
427
- )
428
-
429
- repetition_penalty = gr.Slider(
430
- label="Repetition Penalty",
431
- minimum=1,
432
- maximum=1.5,
433
- value=1.2,
434
- step=0.01,
435
- )
436
-
437
- temperature = gr.Slider(
438
- label="Temperature",
439
- minimum=0.6,
440
- maximum=0.9,
441
- value=0.7,
442
- step=0.01,
443
- )
444
-
445
- with gr.Tab(label="Reference Audio"):
446
- gr.Markdown(
447
- "5 to 10 seconds of reference audio, useful for specifying speaker."
448
- )
449
-
450
- enable_reference_audio = gr.Checkbox(
451
- label="Enable Reference Audio",
452
- )
453
-
454
- # Add dropdown for selecting example audio files
455
- example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")]
456
- example_audio_dropdown = gr.Dropdown(
457
- label="Select Example Audio",
458
- choices=[""] + example_audio_files,
459
- value=""
460
- )
461
-
462
- reference_audio = gr.Audio(
463
- label="Reference Audio",
464
- type="filepath",
465
- )
466
- with gr.Row():
467
- if_auto_label = gr.Checkbox(
468
- label="Auto Labeling",
469
- min_width=100,
470
- scale=0,
471
- value=False,
472
- )
473
- reference_text = gr.Textbox(
474
- label="Reference Text",
475
- lines=1,
476
- placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
477
- value="",
478
- )
479
- with gr.Tab(label="Batch Inference"):
480
- batch_infer_num = gr.Slider(
481
- label="Batch infer nums",
482
- minimum=1,
483
- maximum=n_audios,
484
- step=1,
485
- value=1,
486
- )
 
487
 
488
  with gr.Column(scale=3):
489
- for _ in range(n_audios):
490
- with gr.Row():
491
- error = gr.HTML(
492
- label="Error Message",
493
- visible=True if _ == 0 else False,
494
- )
495
- global_error_list.append(error)
496
- with gr.Row():
497
- audio = gr.Audio(
498
- label="Generated Audio",
499
- type="numpy",
500
- interactive=False,
501
- visible=True if _ == 0 else False,
502
- )
503
- global_audio_list.append(audio)
504
-
505
  with gr.Row():
506
- stream_audio = gr.Audio(
507
- label="Streaming Audio",
508
- streaming=True,
509
- autoplay=True,
 
 
 
 
510
  interactive=False,
511
- show_download_button=True,
512
  )
 
513
  with gr.Row():
514
  with gr.Column(scale=3):
515
  generate = gr.Button(
516
- value="\U0001F3A7 " + "Generate", variant="primary"
517
- )
518
- generate_stream = gr.Button(
519
- value="\U0001F3A7 " + "Streaming Generate",
520
- variant="primary",
521
  )
522
 
523
  text.input(
524
- fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
525
  )
526
 
527
- if_load_asr_model.change(
528
- fn=change_if_load_asr_model,
529
- inputs=[if_load_asr_model],
530
- outputs=[if_load_asr_model],
531
- )
532
-
533
- if_auto_label.change(
534
- fn=lambda: gr.Textbox(value=""),
535
- inputs=[],
536
- outputs=[reference_text],
537
- ).then(
538
- fn=change_if_auto_label,
539
- inputs=[
540
- if_load_asr_model,
541
- if_auto_label,
542
- enable_reference_audio,
543
- reference_audio,
544
- reference_text,
545
- ],
546
- outputs=[reference_text],
547
- )
548
-
549
- def select_example_audio(audio_file):
550
- if audio_file:
551
- audio_path = os.path.join("examples", audio_file)
552
- lab_file = os.path.splitext(audio_file)[0] + ".lab"
553
- lab_path = os.path.join("examples", lab_file)
554
-
555
- if os.path.exists(lab_path):
556
- with open(lab_path, "r", encoding="utf-8") as f:
557
- lab_content = f.read().strip()
558
- else:
559
- lab_content = ""
560
-
561
- return audio_path, lab_content, True
562
- return None, "", False
563
-
564
- # Connect the dropdown to update reference audio and text
565
- example_audio_dropdown.change(
566
- fn=select_example_audio,
567
- inputs=[example_audio_dropdown],
568
- outputs=[reference_audio, reference_text, enable_reference_audio]
569
- )
570
- # # Submit
 
 
571
  generate.click(
572
  inference_wrapper,
573
  [
574
  refined_text,
575
- enable_reference_audio,
 
576
  reference_audio,
577
  reference_text,
578
  max_new_tokens,
@@ -580,26 +521,28 @@ def build_app():
580
  top_p,
581
  repetition_penalty,
582
  temperature,
583
- batch_infer_num,
584
- if_load_asr_model,
585
  ],
586
- [stream_audio, *global_audio_list, *global_error_list],
587
  concurrency_limit=1,
588
  )
 
589
  return app
590
 
591
 
 
592
  def parse_args():
593
  parser = ArgumentParser()
594
  parser.add_argument(
595
  "--llama-checkpoint-path",
596
  type=Path,
597
- default="checkpoints/fish-speech-1.4",
598
  )
599
  parser.add_argument(
600
  "--decoder-checkpoint-path",
601
  type=Path,
602
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
603
  )
604
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
605
  parser.add_argument("--device", type=str, default="cuda")
@@ -634,17 +577,20 @@ if __name__ == "__main__":
634
 
635
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
636
  list(
637
- inference(
638
- text="Hello, world!",
639
- enable_reference_audio=False,
640
- reference_audio=None,
641
- reference_text="",
642
- max_new_tokens=0,
643
- chunk_length=200,
644
- top_p=0.7,
645
- repetition_penalty=1.2,
646
- temperature=0.7,
647
- )
 
 
 
648
  )
649
 
650
  logger.info("Warming up done, launching the web UI...")
 
10
 
11
  # Download if not exists
12
  os.makedirs("checkpoints", exist_ok=True)
13
+ snapshot_download(repo_id="fishaudio/fish-speech-1.5", local_dir="./checkpoints/fish-speech-1.5")
14
 
15
  print("All checkpoints downloaded")
16
 
 
31
  from loguru import logger
32
  from transformers import AutoTokenizer
33
 
34
+ from fish_speech.i18n import i18n
 
35
  from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
36
+ from fish_speech.utils import autocast_exclude_mps, set_seed
37
  from tools.api import decode_vq_tokens, encode_reference
38
+ from tools.file import AUDIO_EXTENSIONS, list_files
39
  from tools.llama.generate import (
40
  GenerateRequest,
41
  GenerateResponse,
 
44
  )
45
  from tools.vqgan.inference import load_model as load_decoder_model
46
 
47
+ from tools.schema import (
48
+ GLOBAL_NUM_SAMPLES,
49
+ ASRPackRequest,
50
+ ServeASRRequest,
51
+ ServeASRResponse,
52
+ ServeASRSegment,
53
+ ServeAudioPart,
54
+ ServeForwardMessage,
55
+ ServeMessage,
56
+ ServeRequest,
57
+ ServeResponse,
58
+ ServeStreamDelta,
59
+ ServeStreamResponse,
60
+ ServeTextPart,
61
+ ServeTimedASRResponse,
62
+ ServeTTSRequest,
63
+ ServeVQGANDecodeRequest,
64
+ ServeVQGANDecodeResponse,
65
+ ServeVQGANEncodeRequest,
66
+ ServeVQGANEncodeResponse,
67
+ ServeVQPart,
68
+ ServeReferenceAudio
69
+ )
70
  # Make einx happy
71
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
72
 
73
 
74
  HEADER_MD = """# Fish Speech
75
 
76
+ ## The demo in this space is version 1.5, Please check [Fish Audio](https://fish.audio) for the best model.
77
+ ## 该 Demo 为 Fish Speech 1.5 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
78
 
79
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
80
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
81
 
82
+ You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).
83
+ 你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型.
84
 
85
  Related code and weights are released under CC BY-NC-SA 4.0 License.
86
  相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
 
88
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
89
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
90
 
91
+ The model running in this WebUI is Fish Speech V1.5 Medium.
92
+ 在此 WebUI 中运行的模型是 Fish Speech V1.5 Medium.
93
  """
94
 
95
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
118
 
119
  @GPU_DECORATOR
120
  @torch.inference_mode()
121
+ def inference(req: ServeTTSRequest):
122
+
123
+ global prompt_tokens, prompt_texts
124
+
125
+ idstr: str | None = req.reference_id
126
+ if idstr is not None:
127
+ ref_folder = Path("references") / idstr
128
+ ref_folder.mkdir(parents=True, exist_ok=True)
129
+ ref_audios = list_files(
130
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
 
 
 
 
 
 
 
 
 
131
  )
132
 
133
+ if req.use_memory_cache == "never" or (
134
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
135
+ ):
136
+ prompt_tokens = [
137
+ encode_reference(
138
+ decoder_model=decoder_model,
139
+ reference_audio=audio_to_bytes(str(ref_audio)),
140
+ enable_reference_audio=True,
141
+ )
142
+ for ref_audio in ref_audios
143
+ ]
144
+ prompt_texts = [
145
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
146
+ for ref_audio in ref_audios
147
+ ]
148
+ else:
149
+ logger.info("Use same references")
150
+
151
+ else:
152
+ # Parse reference audio aka prompt
153
+ refs = req.references
154
+
155
+ if req.use_memory_cache == "never" or (
156
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
157
+ ):
158
+ prompt_tokens = [
159
+ encode_reference(
160
+ decoder_model=decoder_model,
161
+ reference_audio=ref.audio,
162
+ enable_reference_audio=True,
163
+ )
164
+ for ref in refs
165
+ ]
166
+ prompt_texts = [ref.text for ref in refs]
167
+ else:
168
+ logger.info("Use same references")
169
+
170
+ if req.seed is not None:
171
+ set_seed(req.seed)
172
+ logger.warning(f"set seed: {req.seed}")
173
 
174
  # LLAMA Inference
175
  request = dict(
176
  device=decoder_model.device,
177
+ max_new_tokens=req.max_new_tokens,
178
+ text=(
179
+ req.text
180
+ if not req.normalize
181
+ else ChnNormedText(raw_text=req.text).normalize()
182
+ ),
183
+ top_p=req.top_p,
184
+ repetition_penalty=req.repetition_penalty,
185
+ temperature=req.temperature,
186
  compile=args.compile,
187
+ iterative_prompt=req.chunk_length > 0,
188
+ chunk_length=req.chunk_length,
189
+ max_length=4096,
190
+ prompt_tokens=prompt_tokens,
191
+ prompt_text=prompt_texts,
192
  )
193
 
194
  response_queue = queue.Queue()
 
204
  while True:
205
  result: WrappedGenerateResponse = response_queue.get()
206
  if result.status == "error":
207
+ yield None, None, build_html_error_message(result.response)
208
+ break
209
 
210
  result: GenerateResponse = result.response
211
  if result.action == "next":
212
  break
213
 
214
+ with autocast_exclude_mps(
215
+ device_type=decoder_model.device.type, dtype=args.precision
 
 
 
 
 
216
  ):
217
  fake_audios = decode_vq_tokens(
218
  decoder_model=decoder_model,
 
227
  None,
228
  None,
229
  build_html_error_message(
230
+ i18n("No audio generated, please check the input text.")
231
  ),
232
  )
233
 
234
+ # No matter streaming or not, we need to return the final audio
235
  audio = np.concatenate(segments, axis=0)
236
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
237
 
238
  if torch.cuda.is_available():
239
  torch.cuda.empty_cache()
240
  gc.collect()
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  n_audios = 4
243
 
244
  global_audio_list = []
245
  global_error_list = []
246
 
247
+
248
  def inference_wrapper(
249
  text,
250
  enable_reference_audio,
 
255
  top_p,
256
  repetition_penalty,
257
  temperature,
258
+ seed,
259
  batch_infer_num,
 
260
  ):
261
  audios = []
262
  errors = []
263
 
264
  for _ in range(batch_infer_num):
265
+ result = inference(
266
  text,
267
  enable_reference_audio,
268
  reference_audio,
 
272
  top_p,
273
  repetition_penalty,
274
  temperature,
275
+ seed,
276
  )
277
 
278
+ _, audio_data, error_message = next(result)
279
 
280
  audios.append(
281
  gr.Audio(value=audio_data if audio_data else None, visible=True),
 
307
  buffer.close()
308
  return wav_header_bytes
309
 
 
310
  def normalize_text(user_input, use_normalization):
311
  if use_normalization:
312
  return ChnNormedText(raw_text=user_input).normalize()
313
  else:
314
  return user_input
315
 
316
+ def update_examples():
317
+ examples_dir = Path("references")
318
+ examples_dir.mkdir(parents=True, exist_ok=True)
319
+ example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
320
+ return gr.Dropdown(choices=example_audios + [""])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  def build_app():
323
  with gr.Blocks(theme=gr.themes.Base()) as app:
 
335
  with gr.Row():
336
  with gr.Column(scale=3):
337
  text = gr.Textbox(
338
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
339
  )
340
  refined_text = gr.Textbox(
341
+ label=i18n("Realtime Transform Text"),
342
+ placeholder=i18n(
343
+ "Normalization Result Preview (Currently Only Chinese)"
344
+ ),
345
  lines=5,
346
  interactive=False,
347
  )
348
 
349
  with gr.Row():
350
+ normalize = gr.Checkbox(
351
+ label=i18n("Text Normalization"),
 
 
 
 
 
 
352
  value=False,
 
353
  )
354
 
355
  with gr.Row():
356
+ with gr.Column():
357
+ with gr.Tab(label=i18n("Advanced Config")):
358
+ with gr.Row():
359
+ chunk_length = gr.Slider(
360
+ label=i18n("Iterative Prompt Length, 0 means off"),
361
+ minimum=0,
362
+ maximum=300,
363
+ value=200,
364
+ step=8,
365
+ )
366
+
367
+ max_new_tokens = gr.Slider(
368
+ label=i18n(
369
+ "Maximum tokens per batch, 0 means no limit"
370
+ ),
371
+ minimum=0,
372
+ maximum=2048,
373
+ value=0,
374
+ step=8,
375
+ )
376
+
377
+ with gr.Row():
378
+ top_p = gr.Slider(
379
+ label="Top-P",
380
+ minimum=0.6,
381
+ maximum=0.9,
382
+ value=0.7,
383
+ step=0.01,
384
+ )
385
+
386
+ repetition_penalty = gr.Slider(
387
+ label=i18n("Repetition Penalty"),
388
+ minimum=1,
389
+ maximum=1.5,
390
+ value=1.2,
391
+ step=0.01,
392
+ )
393
+
394
+ with gr.Row():
395
+ temperature = gr.Slider(
396
+ label="Temperature",
397
+ minimum=0.6,
398
+ maximum=0.9,
399
+ value=0.7,
400
+ step=0.01,
401
+ )
402
+ seed = gr.Number(
403
+ label="Seed",
404
+ info="0 means randomized inference, otherwise deterministic",
405
+ value=0,
406
+ )
407
+
408
+ with gr.Tab(label=i18n("Reference Audio")):
409
+ with gr.Row():
410
+ gr.Markdown(
411
+ i18n(
412
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
413
+ )
414
+ )
415
+ with gr.Row():
416
+ reference_id = gr.Textbox(
417
+ label=i18n("Reference ID"),
418
+ placeholder="Leave empty to use uploaded references",
419
+ )
420
+
421
+ with gr.Row():
422
+ use_memory_cache = gr.Radio(
423
+ label=i18n("Use Memory Cache"),
424
+ choices=["never", "on-demand", "always"],
425
+ value="on-demand",
426
+ )
427
+
428
+ with gr.Row():
429
+ reference_audio = gr.Audio(
430
+ label=i18n("Reference Audio"),
431
+ type="filepath",
432
+ )
433
+ with gr.Row():
434
+ reference_text = gr.Textbox(
435
+ label=i18n("Reference Text"),
436
+ lines=1,
437
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
438
+ value="",
439
+ )
440
 
441
  with gr.Column(scale=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  with gr.Row():
443
+ error = gr.HTML(
444
+ label=i18n("Error Message"),
445
+ visible=True,
446
+ )
447
+ with gr.Row():
448
+ audio = gr.Audio(
449
+ label=i18n("Generated Audio"),
450
+ type="numpy",
451
  interactive=False,
452
+ visible=True,
453
  )
454
+
455
  with gr.Row():
456
  with gr.Column(scale=3):
457
  generate = gr.Button(
458
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
 
 
 
 
459
  )
460
 
461
  text.input(
462
+ fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]
463
  )
464
 
465
+ def inference_wrapper(
466
+ text,
467
+ normalize,
468
+ reference_id,
469
+ reference_audio,
470
+ reference_text,
471
+ max_new_tokens,
472
+ chunk_length,
473
+ top_p,
474
+ repetition_penalty,
475
+ temperature,
476
+ seed,
477
+ use_memory_cache,
478
+ ):
479
+ references = []
480
+ if reference_audio:
481
+ # 将文件路径转换为字节
482
+ with open(reference_audio, 'rb') as audio_file:
483
+ audio_bytes = audio_file.read()
484
+ references = [
485
+ ServeReferenceAudio(audio=audio_bytes, text=reference_text)
486
+ ]
487
+
488
+ req = ServeTTSRequest(
489
+ text=text,
490
+ normalize=normalize,
491
+ reference_id=reference_id if reference_id else None,
492
+ references=references,
493
+ max_new_tokens=max_new_tokens,
494
+ chunk_length=chunk_length,
495
+ top_p=top_p,
496
+ repetition_penalty=repetition_penalty,
497
+ temperature=temperature,
498
+ seed=int(seed) if seed else None,
499
+ use_memory_cache=use_memory_cache,
500
+ )
501
+
502
+ for result in inference(req):
503
+ if result[2]: # Error message
504
+ return None, result[2]
505
+ elif result[1]: # Audio data
506
+ return result[1], None
507
+
508
+ return None, i18n("No audio generated")
509
+
510
+ # Submit
511
  generate.click(
512
  inference_wrapper,
513
  [
514
  refined_text,
515
+ normalize,
516
+ reference_id,
517
  reference_audio,
518
  reference_text,
519
  max_new_tokens,
 
521
  top_p,
522
  repetition_penalty,
523
  temperature,
524
+ seed,
525
+ use_memory_cache,
526
  ],
527
+ [audio, error],
528
  concurrency_limit=1,
529
  )
530
+
531
  return app
532
 
533
 
534
+
535
  def parse_args():
536
  parser = ArgumentParser()
537
  parser.add_argument(
538
  "--llama-checkpoint-path",
539
  type=Path,
540
+ default="checkpoints/fish-speech-1.5",
541
  )
542
  parser.add_argument(
543
  "--decoder-checkpoint-path",
544
  type=Path,
545
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
546
  )
547
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
548
  parser.add_argument("--device", type=str, default="cuda")
 
577
 
578
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
579
  list(
580
+ inference(
581
+ ServeTTSRequest(
582
+ text="Hello world.",
583
+ references=[],
584
+ reference_id=None,
585
+ max_new_tokens=0,
586
+ chunk_length=200,
587
+ top_p=0.7,
588
+ repetition_penalty=1.5,
589
+ temperature=0.7,
590
+ emotion=None,
591
+ format="wav",
592
+ )
593
+ )
594
  )
595
 
596
  logger.info("Warming up done, launching the web UI...")
fish_speech/callbacks/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from .grad_norm import GradNormMonitor
2
-
3
- __all__ = ["GradNormMonitor"]
 
1
+ from .grad_norm import GradNormMonitor
2
+
3
+ __all__ = ["GradNormMonitor"]
fish_speech/callbacks/grad_norm.py CHANGED
@@ -1,113 +1,113 @@
1
- from typing import Optional, Union
2
-
3
- import lightning.pytorch as pl
4
- import torch
5
- from lightning import LightningModule, Trainer
6
- from lightning.pytorch.callbacks import Callback
7
- from torch import Tensor, nn
8
- from torch.utils._foreach_utils import (
9
- _group_tensors_by_device_and_dtype,
10
- _has_foreach_support,
11
- )
12
-
13
-
14
- @torch.no_grad()
15
- def grad_norm(
16
- parameters: Union[Tensor, list[Tensor]],
17
- norm_type: float = 2.0,
18
- ) -> float:
19
- """
20
- Returns the norm of the gradients of the given parameters.
21
-
22
- Args:
23
- parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
24
- single Tensor that will have gradients normalized
25
- norm_type (float): type of the used p-norm.
26
-
27
- Returns:
28
- Total norm of the parameter gradients (viewed as a single vector).
29
- """ # noqa: E501
30
-
31
- if isinstance(parameters, Tensor):
32
- parameters = [parameters]
33
-
34
- grads = [p.grad for p in parameters if p.grad is not None]
35
- if len(grads) == 0:
36
- return None
37
-
38
- first_device = grads[0].device
39
- grouped_grads: dict[
40
- tuple[torch.device, torch.dtype], list[list[Tensor]]
41
- ] = _group_tensors_by_device_and_dtype(
42
- [[g.detach() for g in grads]]
43
- ) # type: ignore[assignment]
44
-
45
- norms = []
46
- for (device, _), ([grads], _) in grouped_grads.items():
47
- if _has_foreach_support(grads, device=device):
48
- norms.extend(torch._foreach_norm(grads, norm_type))
49
- else:
50
- norms.extend([torch.norm(g, norm_type) for g in grads])
51
-
52
- return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
53
-
54
-
55
- class GradNormMonitor(Callback):
56
- """
57
- Callback that computes the gradient norm of the model parameters.
58
- """
59
-
60
- def __init__(
61
- self,
62
- norm_type: float = 2.0,
63
- logging_interval: str = "step",
64
- sub_module: Optional[Union[str, list[str]]] = None,
65
- ) -> None:
66
- """
67
- Args:
68
- norm_type (float): type of the used p-norm.
69
- logging_interval (str): "step" or "epoch".
70
- """
71
- super().__init__()
72
-
73
- self.norm_type = norm_type
74
- self.logging_interval = logging_interval
75
- self.sub_module = sub_module
76
-
77
- def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
78
- """
79
- Computes the gradient norm of the model parameters and logs it to the logger.
80
-
81
- Args:
82
- trainer (Trainer): The trainer object
83
- model (LightningModule): The current lightningModule
84
- """
85
-
86
- lightning_model = model
87
-
88
- if self.sub_module is None:
89
- return self.log_sub_module_grad_norm(lightning_model, model, "")
90
-
91
- sub_modules = self.sub_module
92
- if isinstance(sub_modules, str):
93
- sub_modules = [sub_modules]
94
-
95
- for sub_module in sub_modules:
96
- self.log_sub_module_grad_norm(
97
- lightning_model, getattr(model, sub_module), f"/{sub_module}"
98
- )
99
-
100
- def log_sub_module_grad_norm(
101
- self, lightning_model: LightningModule, model: nn.Module, path: str
102
- ) -> None:
103
- grad_norm_val = grad_norm(model.parameters(), self.norm_type)
104
- if grad_norm_val is None:
105
- return
106
-
107
- on_step = self.logging_interval == "step"
108
- lightning_model.log(
109
- f"train{path}/grad_norm",
110
- grad_norm_val,
111
- on_step=on_step,
112
- on_epoch=not on_step,
113
- )
 
1
+ from typing import Optional, Union
2
+
3
+ import lightning.pytorch as pl
4
+ import torch
5
+ from lightning import LightningModule, Trainer
6
+ from lightning.pytorch.callbacks import Callback
7
+ from torch import Tensor, nn
8
+ from torch.utils._foreach_utils import (
9
+ _group_tensors_by_device_and_dtype,
10
+ _has_foreach_support,
11
+ )
12
+
13
+
14
+ @torch.no_grad()
15
+ def grad_norm(
16
+ parameters: Union[Tensor, list[Tensor]],
17
+ norm_type: float = 2.0,
18
+ ) -> float:
19
+ """
20
+ Returns the norm of the gradients of the given parameters.
21
+
22
+ Args:
23
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
24
+ single Tensor that will have gradients normalized
25
+ norm_type (float): type of the used p-norm.
26
+
27
+ Returns:
28
+ Total norm of the parameter gradients (viewed as a single vector).
29
+ """ # noqa: E501
30
+
31
+ if isinstance(parameters, Tensor):
32
+ parameters = [parameters]
33
+
34
+ grads = [p.grad for p in parameters if p.grad is not None]
35
+ if len(grads) == 0:
36
+ return None
37
+
38
+ first_device = grads[0].device
39
+ grouped_grads: dict[
40
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
41
+ ] = _group_tensors_by_device_and_dtype(
42
+ [[g.detach() for g in grads]]
43
+ ) # type: ignore[assignment]
44
+
45
+ norms = []
46
+ for (device, _), ([grads], _) in grouped_grads.items():
47
+ if _has_foreach_support(grads, device=device):
48
+ norms.extend(torch._foreach_norm(grads, norm_type))
49
+ else:
50
+ norms.extend([torch.norm(g, norm_type) for g in grads])
51
+
52
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
53
+
54
+
55
+ class GradNormMonitor(Callback):
56
+ """
57
+ Callback that computes the gradient norm of the model parameters.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ norm_type: float = 2.0,
63
+ logging_interval: str = "step",
64
+ sub_module: Optional[Union[str, list[str]]] = None,
65
+ ) -> None:
66
+ """
67
+ Args:
68
+ norm_type (float): type of the used p-norm.
69
+ logging_interval (str): "step" or "epoch".
70
+ """
71
+ super().__init__()
72
+
73
+ self.norm_type = norm_type
74
+ self.logging_interval = logging_interval
75
+ self.sub_module = sub_module
76
+
77
+ def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
78
+ """
79
+ Computes the gradient norm of the model parameters and logs it to the logger.
80
+
81
+ Args:
82
+ trainer (Trainer): The trainer object
83
+ model (LightningModule): The current lightningModule
84
+ """
85
+
86
+ lightning_model = model
87
+
88
+ if self.sub_module is None:
89
+ return self.log_sub_module_grad_norm(lightning_model, model, "")
90
+
91
+ sub_modules = self.sub_module
92
+ if isinstance(sub_modules, str):
93
+ sub_modules = [sub_modules]
94
+
95
+ for sub_module in sub_modules:
96
+ self.log_sub_module_grad_norm(
97
+ lightning_model, getattr(model, sub_module), f"/{sub_module}"
98
+ )
99
+
100
+ def log_sub_module_grad_norm(
101
+ self, lightning_model: LightningModule, model: nn.Module, path: str
102
+ ) -> None:
103
+ grad_norm_val = grad_norm(model.parameters(), self.norm_type)
104
+ if grad_norm_val is None:
105
+ return
106
+
107
+ on_step = self.logging_interval == "step"
108
+ lightning_model.log(
109
+ f"train{path}/grad_norm",
110
+ grad_norm_val,
111
+ on_step=on_step,
112
+ on_epoch=not on_step,
113
+ )
fish_speech/configs/base.yaml CHANGED
@@ -1,87 +1,87 @@
1
- # Base configuration for training a model
2
- paths:
3
- run_dir: results/${project}
4
- ckpt_dir: ${paths.run_dir}/checkpoints
5
-
6
- hydra:
7
- run:
8
- dir: ${paths.run_dir}
9
-
10
- # Lightning Trainer
11
- trainer:
12
- _target_: lightning.pytorch.trainer.Trainer
13
-
14
- default_root_dir: ${paths.run_dir}
15
- accelerator: gpu
16
- num_nodes: 1
17
- devices: auto
18
- strategy:
19
- _target_: lightning.pytorch.strategies.DDPStrategy
20
- process_group_backend: nccl # This should be override when training on windows
21
-
22
- precision: bf16-mixed
23
-
24
- # disable validation by epoch end
25
- check_val_every_n_epoch: null
26
- val_check_interval: 5000
27
- max_steps: 100_000
28
-
29
- # Use torch.backends.cudnn.benchmark to speed up training
30
- benchmark: true
31
-
32
- # Callbacks
33
- callbacks:
34
- model_checkpoint:
35
- _target_: lightning.pytorch.callbacks.ModelCheckpoint
36
- dirpath: ${paths.ckpt_dir}
37
- filename: "step_{step:09d}"
38
- save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
39
- save_top_k: 5 # save 5 latest checkpoints
40
- monitor: step # use step to monitor checkpoints
41
- mode: max # save the latest checkpoint with the highest global_step
42
- every_n_epochs: null # don't save checkpoints by epoch end
43
- every_n_train_steps: 5000 # save checkpoints every 5000 steps
44
- auto_insert_metric_name: false
45
-
46
- model_summary:
47
- _target_: lightning.pytorch.callbacks.ModelSummary
48
- max_depth: 2 # the maximum depth of layer nesting that the summary will include
49
-
50
- learning_rate_monitor:
51
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
52
- logging_interval: step
53
- log_momentum: false
54
-
55
- grad_norm_monitor:
56
- _target_: fish_speech.callbacks.GradNormMonitor
57
- norm_type: 2
58
- logging_interval: step
59
-
60
- # Logger
61
- logger:
62
- tensorboard:
63
- _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
64
- save_dir: "${paths.run_dir}/tensorboard/"
65
- name: null
66
- log_graph: false
67
- default_hp_metric: true
68
- prefix: ""
69
-
70
- # wandb:
71
- # _target_: lightning.pytorch.loggers.wandb.WandbLogger
72
- # # name: "" # name of the run (normally generated by wandb)
73
- # save_dir: "${paths.run_dir}"
74
- # offline: False
75
- # id: null # pass correct id to resume experiment!
76
- # anonymous: null # enable anonymous logging
77
- # project: "fish-speech"
78
- # log_model: False # upload lightning ckpts
79
- # prefix: "" # a string to put at the beginning of metric keys
80
- # # entity: "" # set to name of your wandb team
81
- # group: ""
82
- # tags: ["vq", "hq", "finetune"]
83
- # job_type: ""
84
-
85
- # Loop
86
- train: true
87
- test: false
 
1
+ # Base configuration for training a model
2
+ paths:
3
+ run_dir: results/${project}
4
+ ckpt_dir: ${paths.run_dir}/checkpoints
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${paths.run_dir}
9
+
10
+ # Lightning Trainer
11
+ trainer:
12
+ _target_: lightning.pytorch.trainer.Trainer
13
+
14
+ default_root_dir: ${paths.run_dir}
15
+ accelerator: gpu
16
+ num_nodes: 1
17
+ devices: auto
18
+ strategy:
19
+ _target_: lightning.pytorch.strategies.DDPStrategy
20
+ process_group_backend: nccl # This should be overridden when training on Windows
21
+
22
+ precision: bf16-mixed
23
+
24
+ # disable validation by epoch end
25
+ check_val_every_n_epoch: null
26
+ val_check_interval: 5000
27
+ max_steps: 100_000
28
+
29
+ # Use torch.backends.cudnn.benchmark to speed up training
30
+ benchmark: true
31
+
32
+ # Callbacks
33
+ callbacks:
34
+ model_checkpoint:
35
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
36
+ dirpath: ${paths.ckpt_dir}
37
+ filename: "step_{step:09d}"
38
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
39
+ save_top_k: 5 # save 5 latest checkpoints
40
+ monitor: step # use step to monitor checkpoints
41
+ mode: max # save the latest checkpoint with the highest global_step
42
+ every_n_epochs: null # don't save checkpoints by epoch end
43
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
44
+ auto_insert_metric_name: false
45
+
46
+ model_summary:
47
+ _target_: lightning.pytorch.callbacks.ModelSummary
48
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
49
+
50
+ learning_rate_monitor:
51
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
52
+ logging_interval: step
53
+ log_momentum: false
54
+
55
+ grad_norm_monitor:
56
+ _target_: fish_speech.callbacks.GradNormMonitor
57
+ norm_type: 2
58
+ logging_interval: step
59
+
60
+ # Logger
61
+ logger:
62
+ tensorboard:
63
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
64
+ save_dir: "${paths.run_dir}/tensorboard/"
65
+ name: null
66
+ log_graph: false
67
+ default_hp_metric: true
68
+ prefix: ""
69
+
70
+ # wandb:
71
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
72
+ # # name: "" # name of the run (normally generated by wandb)
73
+ # save_dir: "${paths.run_dir}"
74
+ # offline: False
75
+ # id: null # pass correct id to resume experiment!
76
+ # anonymous: null # enable anonymous logging
77
+ # project: "fish-speech"
78
+ # log_model: False # upload lightning ckpts
79
+ # prefix: "" # a string to put at the beginning of metric keys
80
+ # # entity: "" # set to name of your wandb team
81
+ # group: ""
82
+ # tags: ["vq", "hq", "finetune"]
83
+ # job_type: ""
84
+
85
+ # Loop
86
+ train: true
87
+ test: false
fish_speech/configs/firefly_gan_vq.yaml CHANGED
@@ -1,33 +1,33 @@
1
- _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
2
- spec_transform:
3
- _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
4
- sample_rate: 44100
5
- n_mels: 160
6
- n_fft: 2048
7
- hop_length: 512
8
- win_length: 2048
9
- backbone:
10
- _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
11
- input_channels: 160
12
- depths: [3, 3, 9, 3]
13
- dims: [128, 256, 384, 512]
14
- drop_path_rate: 0.2
15
- kernel_size: 7
16
- head:
17
- _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
18
- hop_length: 512
19
- upsample_rates: [8, 8, 2, 2, 2] # aka. strides
20
- upsample_kernel_sizes: [16, 16, 4, 4, 4]
21
- resblock_kernel_sizes: [3, 7, 11]
22
- resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23
- num_mels: 512
24
- upsample_initial_channel: 512
25
- pre_conv_kernel_size: 13
26
- post_conv_kernel_size: 13
27
- quantizer:
28
- _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
29
- input_dim: 512
30
- n_groups: 8
31
- n_codebooks: 1
32
- levels: [8, 5, 5, 5]
33
- downsample_factor: [2, 2]
 
1
+ _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
2
+ spec_transform:
3
+ _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
4
+ sample_rate: 44100
5
+ n_mels: 160
6
+ n_fft: 2048
7
+ hop_length: 512
8
+ win_length: 2048
9
+ backbone:
10
+ _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
11
+ input_channels: 160
12
+ depths: [3, 3, 9, 3]
13
+ dims: [128, 256, 384, 512]
14
+ drop_path_rate: 0.2
15
+ kernel_size: 7
16
+ head:
17
+ _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
18
+ hop_length: 512
19
+ upsample_rates: [8, 8, 2, 2, 2] # aka. strides
20
+ upsample_kernel_sizes: [16, 16, 4, 4, 4]
21
+ resblock_kernel_sizes: [3, 7, 11]
22
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23
+ num_mels: 512
24
+ upsample_initial_channel: 512
25
+ pre_conv_kernel_size: 13
26
+ post_conv_kernel_size: 13
27
+ quantizer:
28
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
29
+ input_dim: 512
30
+ n_groups: 8
31
+ n_codebooks: 1
32
+ levels: [8, 5, 5, 5]
33
+ downsample_factor: [2, 2]
fish_speech/configs/lora/r_8_alpha_16.yaml CHANGED
@@ -1,4 +1,4 @@
1
- _target_: fish_speech.models.text2semantic.lora.LoraConfig
2
- r: 8
3
- lora_alpha: 16
4
- lora_dropout: 0.01
 
1
+ _target_: fish_speech.models.text2semantic.lora.LoraConfig
2
+ r: 8
3
+ lora_alpha: 16
4
+ lora_dropout: 0.01
fish_speech/configs/model/dual_ar_2_codebook_large.yaml DELETED
@@ -1,9 +0,0 @@
1
- defaults:
2
- - dual_ar_2_codebook_small
3
- - _self_
4
-
5
- config:
6
- n_layer: 30
7
- n_fast_layer: 6
8
- n_head: 24
9
- dim: 1536
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/model/dual_ar_2_codebook_medium.yaml DELETED
@@ -1,9 +0,0 @@
1
- defaults:
2
- - dual_ar_2_codebook_small
3
- - _self_
4
-
5
- config:
6
- n_layer: 24
7
- n_fast_layer: 6
8
- n_head: 16
9
- dim: 1024
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/model/dual_ar_2_codebook_small.yaml DELETED
@@ -1,13 +0,0 @@
1
- _target_: fish_speech.models.text2semantic.llama.DualARTransformer
2
- config:
3
- _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
4
- max_seq_len: ${max_length}
5
- vocab_size: 264 # pad 262 to 8x
6
- n_layer: 12
7
- n_fast_layer: 4
8
- n_head: 12
9
- dim: 768
10
- rope_base: 10000
11
- norm_eps: 1e-5
12
- num_codebooks: 2 # input/output codebook size
13
- codebook_size: 1032 # codebook size 1024 + 2 special tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/model/naive_2_codebook_small.yaml DELETED
@@ -1,12 +0,0 @@
1
- _target_: fish_speech.models.text2semantic.llama.NaiveTransformer
2
- config:
3
- _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
4
- max_seq_len: ${max_length}
5
- vocab_size: 36408
6
- n_layer: 12
7
- n_head: 12
8
- dim: 768
9
- rope_base: 10000
10
- norm_eps: 1e-5
11
- num_codebooks: 2 # input/output codebook size
12
- codebook_size: 1032 # codebook size 1024 + 2 special tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/text2semantic_finetune.yaml CHANGED
@@ -1,83 +1,83 @@
1
- defaults:
2
- - base
3
- - _self_
4
-
5
- project: text2semantic_finetune_dual_ar
6
- max_length: 4096
7
- pretrained_ckpt_path: checkpoints/fish-speech-1.4
8
-
9
- # Lightning Trainer
10
- trainer:
11
- accumulate_grad_batches: 1
12
- gradient_clip_val: 1.0
13
- gradient_clip_algorithm: "norm"
14
- max_steps: 1000
15
- precision: bf16-true
16
- limit_val_batches: 10
17
- val_check_interval: 100
18
-
19
- # Dataset Configuration
20
- tokenizer:
21
- _target_: transformers.AutoTokenizer.from_pretrained
22
- pretrained_model_name_or_path: ${pretrained_ckpt_path}
23
-
24
- # Dataset Configuration
25
- train_dataset:
26
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
27
- proto_files:
28
- - data/protos
29
- tokenizer: ${tokenizer}
30
- causal: true
31
- max_length: ${max_length}
32
- use_speaker: false
33
- interactive_prob: 0.7
34
-
35
- val_dataset:
36
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
37
- proto_files:
38
- - data/protos
39
- tokenizer: ${tokenizer}
40
- causal: true
41
- max_length: ${max_length}
42
- use_speaker: false
43
- interactive_prob: 0.7
44
-
45
- data:
46
- _target_: fish_speech.datasets.semantic.SemanticDataModule
47
- train_dataset: ${train_dataset}
48
- val_dataset: ${val_dataset}
49
- num_workers: 4
50
- batch_size: 8
51
- tokenizer: ${tokenizer}
52
- max_length: ${max_length}
53
-
54
- # Model Configuration
55
- model:
56
- _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
57
- model:
58
- _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
59
- path: ${pretrained_ckpt_path}
60
- load_weights: true
61
- max_length: ${max_length}
62
- lora_config: null
63
-
64
- optimizer:
65
- _target_: torch.optim.AdamW
66
- _partial_: true
67
- lr: 1e-4
68
- weight_decay: 0
69
- betas: [0.9, 0.95]
70
- eps: 1e-5
71
-
72
- lr_scheduler:
73
- _target_: torch.optim.lr_scheduler.LambdaLR
74
- _partial_: true
75
- lr_lambda:
76
- _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
77
- _partial_: true
78
- num_warmup_steps: 10
79
-
80
- # Callbacks
81
- callbacks:
82
- model_checkpoint:
83
- every_n_train_steps: ${trainer.val_check_interval}
 
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ project: text2semantic_finetune_dual_ar
6
+ max_length: 4096
7
+ pretrained_ckpt_path: checkpoints/fish-speech-1.4
8
+
9
+ # Lightning Trainer
10
+ trainer:
11
+ accumulate_grad_batches: 1
12
+ gradient_clip_val: 1.0
13
+ gradient_clip_algorithm: "norm"
14
+ max_steps: 1000
15
+ precision: bf16-true
16
+ limit_val_batches: 10
17
+ val_check_interval: 100
18
+
19
+ # Dataset Configuration
20
+ tokenizer:
21
+ _target_: transformers.AutoTokenizer.from_pretrained
22
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
23
+
24
+ # Dataset Configuration
25
+ train_dataset:
26
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
27
+ proto_files:
28
+ - data/protos
29
+ tokenizer: ${tokenizer}
30
+ causal: true
31
+ max_length: ${max_length}
32
+ use_speaker: false
33
+ interactive_prob: 0.7
34
+
35
+ val_dataset:
36
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
37
+ proto_files:
38
+ - data/protos
39
+ tokenizer: ${tokenizer}
40
+ causal: true
41
+ max_length: ${max_length}
42
+ use_speaker: false
43
+ interactive_prob: 0.7
44
+
45
+ data:
46
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
47
+ train_dataset: ${train_dataset}
48
+ val_dataset: ${val_dataset}
49
+ num_workers: 4
50
+ batch_size: 8
51
+ tokenizer: ${tokenizer}
52
+ max_length: ${max_length}
53
+
54
+ # Model Configuration
55
+ model:
56
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
57
+ model:
58
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
59
+ path: ${pretrained_ckpt_path}
60
+ load_weights: true
61
+ max_length: ${max_length}
62
+ lora_config: null
63
+
64
+ optimizer:
65
+ _target_: torch.optim.AdamW
66
+ _partial_: true
67
+ lr: 1e-4
68
+ weight_decay: 0
69
+ betas: [0.9, 0.95]
70
+ eps: 1e-5
71
+
72
+ lr_scheduler:
73
+ _target_: torch.optim.lr_scheduler.LambdaLR
74
+ _partial_: true
75
+ lr_lambda:
76
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
77
+ _partial_: true
78
+ num_warmup_steps: 10
79
+
80
+ # Callbacks
81
+ callbacks:
82
+ model_checkpoint:
83
+ every_n_train_steps: ${trainer.val_check_interval}
fish_speech/configs/text2semantic_finetune_lora.yaml DELETED
@@ -1,13 +0,0 @@
1
- defaults:
2
- - text2semantic_finetune
3
- - _self_
4
-
5
- project: text2semantic_finetune_dual_ar_lora
6
-
7
- # Model Configuration
8
- model:
9
- save_lora_only: true
10
- lora_config:
11
- _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
12
- r: 8
13
- lora_alpha: 16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/text2semantic_pretrain.yaml DELETED
@@ -1,74 +0,0 @@
1
- defaults:
2
- - base
3
- - model@model.model: dual_ar_2_codebook_small
4
- - _self_
5
-
6
- project: text2semantic_pretrain_dual_ar_debug
7
- max_length: 2048
8
-
9
- # Lightning Trainer
10
- trainer:
11
- accumulate_grad_batches: 1
12
- gradient_clip_val: 1.0
13
- gradient_clip_algorithm: 'norm'
14
- max_steps: 1_000_000
15
- precision: bf16-true
16
- limit_val_batches: 10
17
-
18
- # Dataset Configuration
19
- tokenizer:
20
- _target_: transformers.AutoTokenizer.from_pretrained
21
- pretrained_model_name_or_path: fishaudio/fish-speech-1
22
-
23
- # Dataset Configuration
24
- train_dataset:
25
- _target_: fish_speech.datasets.text.AutoAugTextDataset
26
- proto_files:
27
- - data/protos/train
28
- tokenizer: ${tokenizer}
29
- max_length: ${max_length}
30
- num_codebooks: ${model.model.config.num_codebooks}
31
- use_speaker: false
32
- interactive_prob: 0.5
33
-
34
- val_dataset:
35
- _target_: fish_speech.datasets.text.AutoAugTextDataset
36
- proto_files:
37
- - data/protos/test
38
- tokenizer: ${tokenizer}
39
- max_length: ${max_length}
40
- num_codebooks: ${model.model.config.num_codebooks}
41
- use_speaker: false
42
- interactive_prob: 0.5
43
-
44
- data:
45
- _target_: fish_speech.datasets.text.TextDataModule
46
- train_dataset: ${train_dataset}
47
- val_dataset: ${val_dataset}
48
- num_workers: 4
49
- batch_size: 8
50
- tokenizer: ${tokenizer}
51
- max_length: ${max_length}
52
-
53
- # Model Configuration
54
- model:
55
- _target_: fish_speech.models.text2semantic.TextToSemantic
56
- model: {}
57
-
58
- optimizer:
59
- _target_: torch.optim.AdamW
60
- _partial_: true
61
- lr: 3e-4
62
- weight_decay: 0.01
63
- betas: [0.9, 0.95]
64
- eps: 1e-5
65
-
66
- lr_scheduler:
67
- _target_: torch.optim.lr_scheduler.LambdaLR
68
- _partial_: true
69
- lr_lambda:
70
- _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
71
- _partial_: true
72
- num_warmup_steps: 2000
73
- num_training_steps: ${trainer.max_steps}
74
- final_lr_ratio: 0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/text2semantic_sft.yaml DELETED
@@ -1,87 +0,0 @@
1
- defaults:
2
- - base
3
- - model@model.model: dual_ar_8_codebook_small
4
- - _self_
5
-
6
- project: text2semantic_sft_medium_dual_ar
7
- max_length: 4096
8
- ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
9
- resume_weights_only: true
10
-
11
- # Lightning Trainer
12
- trainer:
13
- accumulate_grad_batches: 1
14
- gradient_clip_val: 1.0
15
- gradient_clip_algorithm: 'norm'
16
- max_steps: 10_000
17
- precision: bf16-true
18
- limit_val_batches: 10
19
- val_check_interval: 500
20
-
21
- # Dataset Configuration
22
- tokenizer:
23
- _target_: transformers.AutoTokenizer.from_pretrained
24
- pretrained_model_name_or_path: fishaudio/speech-lm-v1
25
-
26
- # Dataset Configuration
27
- train_dataset:
28
- _target_: fish_speech.datasets.text.AutoAugTextDataset
29
- use_data_server: false
30
- proto_files:
31
- - data/protos/sft/train_Genshin.protos
32
- - data/protos/sft/sft.protos
33
- tokenizer: ${tokenizer}
34
- max_length: ${max_length}
35
- num_codebooks: ${model.model.config.num_codebooks}
36
- use_speaker: false
37
- phones_prob: 0.5
38
- interactive_prob: 0.5
39
-
40
- val_dataset:
41
- _target_: fish_speech.datasets.text.AutoAugTextDataset
42
- use_data_server: false
43
- proto_files:
44
- - data/protos/sft/val_Genshin.protos
45
- tokenizer: ${tokenizer}
46
- max_length: ${max_length}
47
- num_codebooks: ${model.model.config.num_codebooks}
48
- use_speaker: false
49
- phones_prob: 0.5
50
- interactive_prob: 0.5
51
-
52
- data:
53
- _target_: fish_speech.datasets.text.TextDataModule
54
- train_dataset: ${train_dataset}
55
- val_dataset: ${val_dataset}
56
- num_workers: 4
57
- batch_size: 8
58
- tokenizer: ${tokenizer}
59
- max_length: ${max_length}
60
-
61
- # Model Configuration
62
- model:
63
- _target_: fish_speech.models.text2semantic.TextToSemantic
64
- model: {}
65
-
66
- optimizer:
67
- _target_: torch.optim.AdamW
68
- _partial_: true
69
- lr: 4e-5
70
- weight_decay: 0
71
- betas: [0.9, 0.95]
72
- eps: 1e-5
73
-
74
- lr_scheduler:
75
- _target_: torch.optim.lr_scheduler.LambdaLR
76
- _partial_: true
77
- lr_lambda:
78
- _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
79
- _partial_: true
80
- num_warmup_steps: 100
81
- num_training_steps: ${trainer.max_steps}
82
- final_lr_ratio: 0
83
-
84
- callbacks:
85
- model_checkpoint:
86
- every_n_train_steps: 1000
87
- save_top_k: 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/vqgan_finetune.yaml DELETED
@@ -1,135 +0,0 @@
1
- defaults:
2
- - base
3
- - _self_
4
-
5
- project: vq-gan-finetune
6
- ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
7
- resume_weights_only: true
8
-
9
- # Lightning Trainer
10
- trainer:
11
- accelerator: gpu
12
- devices: auto
13
- precision: bf16-mixed
14
- max_steps: 100_000
15
- val_check_interval: 5000
16
- strategy: ddp_find_unused_parameters_true
17
-
18
- sample_rate: 44100
19
- hop_length: 512
20
- num_mels: 128
21
- n_fft: 2048
22
- win_length: 2048
23
- freeze_encoder: true
24
-
25
- # Dataset Configuration
26
- train_dataset:
27
- _target_: fish_speech.datasets.vqgan.VQGANDataset
28
- filelist: data/filelist.train.txt
29
- sample_rate: ${sample_rate}
30
- hop_length: ${hop_length}
31
- slice_frames: 512
32
-
33
- val_dataset:
34
- _target_: fish_speech.datasets.vqgan.VQGANDataset
35
- filelist: data/filelist.val.txt
36
- sample_rate: ${sample_rate}
37
- hop_length: ${hop_length}
38
-
39
- data:
40
- _target_: fish_speech.datasets.vqgan.VQGANDataModule
41
- train_dataset: ${train_dataset}
42
- val_dataset: ${val_dataset}
43
- num_workers: 4
44
- batch_size: 16
45
- val_batch_size: 16
46
-
47
- # Model Configuration
48
- model:
49
- _target_: fish_speech.models.vqgan.VQGAN
50
-
51
- sampling_rate: ${sample_rate}
52
- weight_adv: 0.2
53
- weight_vq: 1.0
54
- weight_mel: 1.0
55
- freeze_encoder: false
56
-
57
- encoder:
58
- _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
59
- input_channels: ${num_mels}
60
- residual_channels: 768
61
- residual_layers: 20
62
- dilation_cycle: 4
63
-
64
- quantizer:
65
- _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
66
- input_dim: 768
67
- n_codebooks: 1
68
- n_groups: 2
69
- levels: [8, 5, 5, 5]
70
-
71
- decoder:
72
- _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
73
- output_channels: ${num_mels}
74
- residual_channels: 768
75
- residual_layers: 20
76
- dilation_cycle: 4
77
- condition_channels: 768
78
-
79
- discriminator:
80
- _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
81
-
82
- vocoder:
83
- _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
84
- ckpt_path: null # You may download the pretrained vocoder and set the path here
85
-
86
- encode_mel_transform:
87
- _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
88
- sample_rate: ${sample_rate}
89
- n_fft: ${n_fft}
90
- hop_length: ${hop_length}
91
- win_length: ${win_length}
92
- n_mels: ${num_mels}
93
- f_min: 0.0
94
- f_max: 8000.0
95
-
96
- gt_mel_transform:
97
- _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
98
- sample_rate: ${sample_rate}
99
- n_fft: ${n_fft}
100
- hop_length: ${hop_length}
101
- win_length: ${win_length}
102
- n_mels: ${num_mels}
103
-
104
- optimizer:
105
- _target_: torch.optim.AdamW
106
- _partial_: true
107
- lr: 4e-5
108
- betas: [0.8, 0.99]
109
- eps: 1e-5
110
- weight_decay: 0.01
111
-
112
- lr_scheduler:
113
- _target_: torch.optim.lr_scheduler.LambdaLR
114
- _partial_: true
115
- lr_lambda:
116
- _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
117
- _partial_: true
118
- num_warmup_steps: 100
119
- num_training_steps: ${trainer.max_steps}
120
- final_lr_ratio: 0
121
-
122
- callbacks:
123
- model_summary:
124
- _target_: lightning.pytorch.callbacks.ModelSummary
125
- max_depth: 1
126
-
127
- model_checkpoint:
128
- every_n_train_steps: ${trainer.val_check_interval}
129
-
130
- grad_norm_monitor:
131
- sub_module:
132
- - encoder
133
- - decoder
134
- - quantizer
135
- - discriminator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/configs/vqgan_pretrain.yaml DELETED
@@ -1,139 +0,0 @@
1
- defaults:
2
- - base
3
- - _self_
4
-
5
- project: vq-gan-pretrain
6
-
7
- # Lightning Trainer
8
- trainer:
9
- accelerator: gpu
10
- devices: auto
11
- precision: bf16-mixed
12
- max_steps: 1_000_000
13
- val_check_interval: 5000
14
- strategy: ddp_find_unused_parameters_true
15
-
16
- sample_rate: 44100
17
- hop_length: 512
18
- num_mels: 128
19
- n_fft: 2048
20
- win_length: 2048
21
-
22
- # Dataset Configuration
23
- train_dataset:
24
- _target_: torch.utils.data.ConcatDataset
25
- datasets:
26
- - _target_: fish_speech.datasets.vqgan.VQGANDataset
27
- filelist: data/gigaspeech/vq_train_filelist.txt
28
- sample_rate: ${sample_rate}
29
- hop_length: ${hop_length}
30
- slice_frames: 512
31
- - _target_: fish_speech.datasets.vqgan.VQGANDataset
32
- filelist: data/sft/vq_train_filelist.txt
33
- sample_rate: ${sample_rate}
34
- hop_length: ${hop_length}
35
- slice_frames: 512
36
-
37
- val_dataset:
38
- _target_: fish_speech.datasets.vqgan.VQGANDataset
39
- filelist: data/sft/vq_val_filelist.txt
40
- sample_rate: ${sample_rate}
41
- hop_length: ${hop_length}
42
-
43
- data:
44
- _target_: fish_speech.datasets.vqgan.VQGANDataModule
45
- train_dataset: ${train_dataset}
46
- val_dataset: ${val_dataset}
47
- num_workers: 4
48
- batch_size: 32
49
- val_batch_size: 32
50
-
51
- # Model Configuration
52
- model:
53
- _target_: fish_speech.models.vqgan.VQGAN
54
-
55
- sampling_rate: ${sample_rate}
56
- weight_adv: 0.2
57
- weight_vq: 1.0
58
- weight_mel: 1.0
59
- freeze_encoder: false
60
-
61
- encoder:
62
- _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
63
- input_channels: ${num_mels}
64
- residual_channels: 768
65
- residual_layers: 20
66
- dilation_cycle: 4
67
-
68
- quantizer:
69
- _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
70
- input_dim: 768
71
- n_codebooks: 1
72
- n_groups: 2
73
- levels: [8, 5, 5, 5]
74
-
75
- decoder:
76
- _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
77
- output_channels: ${num_mels}
78
- residual_channels: 768
79
- residual_layers: 20
80
- dilation_cycle: 4
81
- condition_channels: 768
82
-
83
- discriminator:
84
- _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
85
-
86
- vocoder:
87
- _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
88
- ckpt_path: null # You may download the pretrained vocoder and set the path here
89
-
90
- encode_mel_transform:
91
- _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
92
- sample_rate: ${sample_rate}
93
- n_fft: ${n_fft}
94
- hop_length: ${hop_length}
95
- win_length: ${win_length}
96
- n_mels: ${num_mels}
97
- f_min: 0.0
98
- f_max: 8000.0
99
-
100
- gt_mel_transform:
101
- _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
102
- sample_rate: ${sample_rate}
103
- n_fft: ${n_fft}
104
- hop_length: ${hop_length}
105
- win_length: ${win_length}
106
- n_mels: ${num_mels}
107
-
108
- optimizer:
109
- _target_: torch.optim.AdamW
110
- _partial_: true
111
- lr: 1e-4
112
- betas: [0.8, 0.99]
113
- eps: 1e-5
114
- weight_decay: 0.01
115
-
116
- lr_scheduler:
117
- _target_: torch.optim.lr_scheduler.LambdaLR
118
- _partial_: true
119
- lr_lambda:
120
- _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
121
- _partial_: true
122
- num_warmup_steps: 100
123
- num_training_steps: ${trainer.max_steps}
124
- final_lr_ratio: 0
125
-
126
- callbacks:
127
- model_summary:
128
- _target_: lightning.pytorch.callbacks.ModelSummary
129
- max_depth: 1
130
-
131
- model_checkpoint:
132
- every_n_train_steps: ${trainer.val_check_interval}
133
-
134
- grad_norm_monitor:
135
- sub_module:
136
- - encoder
137
- - decoder
138
- - quantizer
139
- - discriminator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/conversation.py CHANGED
@@ -1,2 +1,267 @@
1
- SEMANTIC_TOKEN = "<|semantic|>"
2
- CODEBOOK_PAD_TOKEN_ID = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal
3
+
4
+ import torch
5
+
6
+ from .tokenizer import MODALITY_TOKENS, FishTokenizer
7
+
8
+ CODEBOOK_PAD_TOKEN_ID = 0
9
+
10
+
11
+ @dataclass(kw_only=True)
12
+ class BasePart:
13
+ pass
14
+
15
+
16
@dataclass(kw_only=True)
class VQPart(BasePart):
    """Message part that carries a tensor of VQ codes (keyword-only construction)."""

    # Row 0 of this tensor is what Message.encode maps to token ids.
    codes: torch.Tensor
19
+
20
+
21
@dataclass(kw_only=True)
class TextPart(BasePart):
    """Message part that carries plain text (keyword-only construction)."""

    # Raw string that is passed to the tokenizer's encode().
    text: str
24
+
25
+
26
+ @dataclass(kw_only=True)
27
+ class EncodedMessage:
28
+ tokens: torch.Tensor
29
+ labels: torch.Tensor
30
+ vq_mask_tokens: torch.Tensor | None = None
31
+ vq_mask_labels: torch.Tensor | None = None
32
+ vq_parts: list[torch.Tensor]
33
+ vq_require_losses: torch.Tensor | None = None
34
+
35
+
36
@dataclass(kw_only=True)
class Message:
    """One chat turn made of text and/or VQ parts, encodable to token ids."""

    role: Literal["system", "user", "assistant"]
    parts: list[VQPart | TextPart] = field(default_factory=list)
    add_im_start: bool = True
    add_im_end: bool = True
    cal_loss: bool = False
    modality: Literal["text", "voice", "interleave"] | None = None

    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True

    def encode(
        self: "Message",
        tokenizer: FishTokenizer,
    ) -> EncodedMessage:
        """Encode this message into tokens, labels and VQ masks.

        Args:
            tokenizer: Provides ``encode(text)`` for text parts and the
                ``semantic_id_to_token_id`` lookup for VQ parts.

        Returns:
            An ``EncodedMessage`` whose ``tokens``, ``labels`` and VQ masks all
            have the same shape. ``vq_mask_tokens`` and ``vq_mask_labels`` are
            the same tensor object.

        Raises:
            ValueError: If a part is neither a ``TextPart`` nor a ``VQPart``.
        """
        # Work on a copy so the synthetic header/footer never leak into self.parts.
        sequence = list(self.parts)
        if self.add_im_start:
            prefix = MODALITY_TOKENS[self.modality] if self.modality else ""
            sequence.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{prefix}"))
        if self.add_im_end:
            sequence.append(TextPart(text="<|im_end|>"))

        token_chunks = []
        label_chunks = []
        mask_chunks = []
        vq_parts = []

        for part in sequence:
            if isinstance(part, TextPart):
                ids = tokenizer.encode(part.text)
                codes = None
            elif isinstance(part, VQPart):
                codes = part.codes.clone()
                # Only the first row of the codes is mapped to token ids.
                ids = [
                    tokenizer.semantic_id_to_token_id[code.item()]
                    for code in codes[0].int()
                ]
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            chunk = torch.tensor(ids, dtype=torch.int)
            if codes is not None:
                vq_parts.append(codes)

            token_chunks.append(chunk)
            # True on every position of a VQ part, False on text positions.
            mask_chunks.append(
                torch.full_like(chunk, isinstance(part, VQPart), dtype=torch.bool)
            )
            # -100 is the label value used for positions excluded from the loss.
            label_chunks.append(
                chunk.clone() if self.cal_loss else torch.full_like(chunk, -100)
            )

        tokens = torch.cat(token_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=0)
        vq_masks = torch.cat(mask_chunks, dim=0)

        assert tokens.shape == labels.shape == vq_masks.shape

        # The first chunk is the generated "<|im_start|>..." header.
        if self.ignore_im_start_loss and self.add_im_start:
            labels[: len(token_chunks[0])] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            vq_mask_tokens=vq_masks,
            vq_mask_labels=vq_masks,
        )
113
+
114
+
115
@dataclass
class Conversation:
    """Ordered list of messages that can be encoded as one token sequence.

    ``__init__`` is written by hand so that ``messages`` is optional; because
    the class defines it, ``@dataclass`` does not generate its own.
    """

    messages: list[Message]

    def __init__(self: "Conversation", messages: list[Message] | None = None):
        # A falsy argument (None or an empty list) results in a fresh list.
        self.messages = messages or []

    def encode(
        self: "Conversation",
        tokenizer: FishTokenizer,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] | None = None,
    ) -> EncodedMessage:
        """Encode every message and concatenate the results.

        Args:
            tokenizer: Passed to ``Message.encode``; ``get_token_id`` is used
                to resolve ``ignore_loss_tokens``.
            add_shift: When true, drop the last token/mask entry and the first
                label/mask entry so labels are shifted by one position
                relative to tokens.
            ignore_loss_tokens: Token strings whose ids are replaced by -100
                in the labels. ``None`` (the default) means no extra tokens.

        Returns:
            An ``EncodedMessage`` covering the whole conversation, including
            ``vq_require_losses`` (one bool per VQ part, taken from the owning
            message's ``cal_loss``).
        """
        # None replaces the former mutable ``[]`` default.
        ignore_loss_token_ids = [
            tokenizer.get_token_id(i) for i in (ignore_loss_tokens or [])
        ]

        # Build the input_ids and labels
        tokens = []
        labels = []
        vq_parts = []
        vq_mask_tokens = []
        vq_mask_labels = []
        vq_require_losses = []

        for message in self.messages:
            encoded = message.encode(tokenizer)
            tokens.append(encoded.tokens)
            labels.append(encoded.labels)
            vq_parts.extend(encoded.vq_parts)
            vq_mask_tokens.append(encoded.vq_mask_tokens)
            vq_mask_labels.append(encoded.vq_mask_labels)
            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

        tokens = torch.cat(tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
        vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        if add_shift:
            tokens = tokens[:-1]
            labels = labels[1:]
            vq_mask_tokens = vq_mask_tokens[:-1]
            vq_mask_labels = vq_mask_labels[1:]

        for i in ignore_loss_token_ids:
            assert i != -100 and i is not None
            labels[labels == i] = -100

        # The message used to reference an undefined name (``conversation``),
        # which turned a failing check into a NameError; report ``self``.
        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            vq_mask_tokens=vq_mask_tokens,
            vq_mask_labels=vq_mask_labels,
            vq_require_losses=vq_require_losses,
        )

    def encode_for_inference(
        self: "Conversation",
        tokenizer: FishTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
        """Build the integer prompt tensor used for inference.

        Args:
            tokenizer: Used for encoding; ``semantic_begin_id`` offsets the
                first codebook row when writing VQ positions into row 0.
            num_codebooks: Number of codebook rows below the token row.

        Returns:
            An int tensor of shape ``(num_codebooks + 1, len(tokens))``. Row 0
            holds the token ids; rows 1.. are zero except at VQ positions,
            where they hold the concatenated codes of all VQ parts.
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
            return values

        vq_parts = [part.to(values.device) for part in encoded.vq_parts]
        # Join all VQ parts along the time axis (dim 1).
        vq_parts = torch.cat(vq_parts, dim=1)
        values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
        values[1:, encoded.vq_mask_tokens] = vq_parts

        return values

    def visualize(
        self: "Conversation",
        tokenizer: FishTokenizer,
        ignore_loss_tokens: list[str] | None = None,
    ):
        """Print the decoded conversation with ANSI colors, token by token.

        Tokens whose label is -100 are printed in alternating light/dark
        green; all other tokens in alternating light blue/cyan.

        Args:
            tokenizer: Used to encode the conversation and decode each token.
            ignore_loss_tokens: Forwarded to ``encode``; ``None`` means none.
        """
        encoded = self.encode(
            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
        )

        # Two shades per group, alternated so adjacent tokens stay distinguishable.
        palettes = {
            True: ("\033[92m", "\033[32m"),  # ignored: light green / dark green
            False: ("\033[94m", "\033[96m"),  # trained: light blue / cyan
        }
        printed = {True: 0, False: 0}

        for tok, lab in zip(encoded.tokens, encoded.labels):
            val = tokenizer.decode([tok])
            ignored = bool(lab == -100)
            color = palettes[ignored][printed[ignored] % 2]
            print(f"{color}{val}\033[0m", end="")
            printed[ignored] += 1

        print()

    def append(self: "Conversation", message: Message):
        """Add ``message`` to the end of the conversation."""
        self.messages.append(message)
244
+
245
+
246
if __name__ == "__main__":
    # Manual smoke test: a user turn (text plus an all-zero 4x10 VQ part, no
    # loss) followed by an assistant text turn that is trained on.
    message0 = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )
    message1 = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )

    conversation = Conversation([message0, message1])
    # Requires the tokenizer checkpoint to be present at this local path.
    tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))
fish_speech/datasets/concat_repeat.py CHANGED
@@ -1,53 +1,53 @@
1
- import bisect
2
- import random
3
- from typing import Iterable
4
-
5
- from torch.utils.data import Dataset, IterableDataset
6
-
7
-
8
- class ConcatRepeatDataset(Dataset):
9
- datasets: list[Dataset]
10
- cumulative_sizes: list[int]
11
- repeats: list[int]
12
-
13
- @staticmethod
14
- def cumsum(sequence, repeats):
15
- r, s = [], 0
16
- for dataset, repeat in zip(sequence, repeats):
17
- l = len(dataset) * repeat
18
- r.append(l + s)
19
- s += l
20
- return r
21
-
22
- def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
23
- super().__init__()
24
-
25
- self.datasets = list(datasets)
26
- self.repeats = repeats
27
-
28
- assert len(self.datasets) > 0, "datasets should not be an empty iterable"
29
- assert len(self.datasets) == len(
30
- repeats
31
- ), "datasets and repeats should have the same length"
32
-
33
- for d in self.datasets:
34
- assert not isinstance(
35
- d, IterableDataset
36
- ), "ConcatRepeatDataset does not support IterableDataset"
37
-
38
- self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
39
-
40
- def __len__(self):
41
- return self.cumulative_sizes[-1]
42
-
43
- def __getitem__(self, idx):
44
- dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
45
-
46
- if dataset_idx == 0:
47
- sample_idx = idx
48
- else:
49
- sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
50
-
51
- dataset = self.datasets[dataset_idx]
52
-
53
- return dataset[sample_idx % len(dataset)]
 
1
+ import bisect
2
+ import random
3
+ from typing import Iterable
4
+
5
+ from torch.utils.data import Dataset, IterableDataset
6
+
7
+
8
class ConcatRepeatDataset(Dataset):
    """Concatenation of map-style datasets, each virtually repeated N times.

    Dataset ``i`` contributes ``len(datasets[i]) * repeats[i]`` consecutive
    indices; repeated indices wrap around to the same underlying items, so
    nothing is copied.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        """Return the running totals of ``len(dataset) * repeat``."""
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        """Store the datasets and precompute their cumulative repeated sizes.

        Args:
            datasets: Non-empty iterable of map-style datasets.
            repeats: Repeat count for each dataset, same length as ``datasets``.

        Raises:
            AssertionError: If ``datasets`` is empty, the lengths differ, or
                any dataset is an ``IterableDataset``.
        """
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        """Return the total number of (repeated) items."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the item at global position ``idx``.

        Negative indices count from the end, as for Python sequences.
        Previously they were resolved against the first dataset only, so
        ``ds[-1]`` returned the wrong item.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise IndexError("ConcatRepeatDataset index out of range")
            idx += total
        elif idx >= total:
            raise IndexError("ConcatRepeatDataset index out of range")

        # First dataset whose cumulative end is strictly greater than idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx else 0
        dataset = self.datasets[dataset_idx]

        # Modulo folds the repeated copies back onto the real items.
        return dataset[(idx - start) % len(dataset)]
fish_speech/datasets/protos/text-data.proto CHANGED
@@ -1,24 +1,24 @@
1
- syntax = "proto3";
2
-
3
- package text_data;
4
-
5
- message Semantics {
6
- repeated uint32 values = 1;
7
- }
8
-
9
- message Sentence {
10
- repeated string texts = 1;
11
- repeated Semantics semantics = 3;
12
- }
13
-
14
- message TextData {
15
- string source = 1;
16
- string name = 2;
17
- repeated Sentence sentences = 4;
18
- }
19
-
20
- message SampledData {
21
- string source = 1;
22
- string name = 2;
23
- repeated Sentence samples = 3;
24
- }
 
1
// Message schema for the text / semantic-token dataset records.
syntax = "proto3";

package text_data;

// One sequence of unsigned integer values.
// NOTE(review): appears to hold the codes of a single codebook - confirm against the dataset code.
message Semantics {
    repeated uint32 values = 1;
}

// A sentence: its text strings plus its Semantics sequences.
// NOTE(review): field number 2 is not used here - presumably retired; confirm before reusing it.
message Sentence {
    repeated string texts = 1;
    repeated Semantics semantics = 3;
}

// A named group of sentences coming from one source.
// NOTE(review): field number 3 is not used here - presumably retired; confirm before reusing it.
message TextData {
    string source = 1;
    string name = 2;
    repeated Sentence sentences = 4;
}

// A selection of sentences ("samples") together with their source and name.
message SampledData {
    string source = 1;
    string name = 2;
    repeated Sentence samples = 3;
}
fish_speech/datasets/protos/text_data_pb2.py CHANGED
@@ -1,33 +1,33 @@
1
- # -*- coding: utf-8 -*-
2
- # Generated by the protocol buffer compiler. DO NOT EDIT!
3
- # source: text-data.proto
4
- # Protobuf Python Version: 4.25.1
5
- """Generated protocol buffer code."""
6
- from google.protobuf import descriptor as _descriptor
7
- from google.protobuf import descriptor_pool as _descriptor_pool
8
- from google.protobuf import symbol_database as _symbol_database
9
- from google.protobuf.internal import builder as _builder
10
-
11
- # @@protoc_insertion_point(imports)
12
-
13
- _sym_db = _symbol_database.Default()
14
-
15
-
16
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
17
- b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
18
- )
19
-
20
- _globals = globals()
21
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
23
- if _descriptor._USE_C_DESCRIPTORS == False:
24
- DESCRIPTOR._options = None
25
- _globals["_SEMANTICS"]._serialized_start = 30
26
- _globals["_SEMANTICS"]._serialized_end = 57
27
- _globals["_SENTENCE"]._serialized_start = 59
28
- _globals["_SENTENCE"]._serialized_end = 125
29
- _globals["_TEXTDATA"]._serialized_start = 127
30
- _globals["_TEXTDATA"]._serialized_end = 207
31
- _globals["_SAMPLEDDATA"]._serialized_start = 209
32
- _globals["_SAMPLEDDATA"]._serialized_end = 290
33
- # @@protoc_insertion_point(module_scope)
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: text-data.proto
4
+ # Protobuf Python Version: 4.25.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+
11
+ # @@protoc_insertion_point(imports)
12
+
13
+ _sym_db = _symbol_database.Default()
14
+
15
+
16
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
17
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
18
+ )
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ _globals["_SEMANTICS"]._serialized_start = 30
26
+ _globals["_SEMANTICS"]._serialized_end = 57
27
+ _globals["_SENTENCE"]._serialized_start = 59
28
+ _globals["_SENTENCE"]._serialized_end = 125
29
+ _globals["_TEXTDATA"]._serialized_start = 127
30
+ _globals["_TEXTDATA"]._serialized_end = 207
31
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
32
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
33
+ # @@protoc_insertion_point(module_scope)
fish_speech/datasets/protos/text_data_stream.py CHANGED
@@ -1,36 +1,36 @@
1
- import struct
2
-
3
- from .text_data_pb2 import TextData
4
-
5
-
6
- def read_pb_stream(f):
7
- while True:
8
- buf = f.read(4)
9
- if len(buf) == 0:
10
- break
11
- size = struct.unpack("I", buf)[0]
12
- buf = f.read(size)
13
- text_data = TextData()
14
- text_data.ParseFromString(buf)
15
- yield text_data
16
-
17
-
18
- def write_pb_stream(f, text_data):
19
- buf = text_data.SerializeToString()
20
- f.write(struct.pack("I", len(buf)))
21
- f.write(buf)
22
-
23
-
24
- def pack_pb_stream(text_data):
25
- buf = text_data.SerializeToString()
26
- return struct.pack("I", len(buf)) + buf
27
-
28
-
29
- def split_pb_stream(f):
30
- while True:
31
- head = f.read(4)
32
- if len(head) == 0:
33
- break
34
- size = struct.unpack("I", head)[0]
35
- buf = f.read(size)
36
- yield head + buf
 
1
+ import struct
2
+
3
+ from .text_data_pb2 import TextData
4
+
5
+
6
def read_pb_stream(f):
    """Yield ``TextData`` messages from a length-prefixed binary stream.

    Each record is a 4-byte native-order unsigned int (``struct`` format
    ``"I"``) giving the payload size, followed by that many payload bytes.

    Args:
        f: Binary file-like object supporting ``read(n)``.

    Yields:
        One parsed ``TextData`` per record, until the stream is exhausted.

    Raises:
        struct.error: If the stream ends inside a 4-byte size header.
        EOFError: If the stream ends before a full payload was read
            (previously the short payload was handed to the parser unchecked).
    """
    while True:
        header = f.read(4)
        if len(header) == 0:
            break

        (size,) = struct.unpack("I", header)
        payload = f.read(size)
        if len(payload) != size:
            raise EOFError(
                f"Truncated record: expected {size} bytes, got {len(payload)}"
            )

        text_data = TextData()
        text_data.ParseFromString(payload)
        yield text_data
16
+
17
+
18
def write_pb_stream(f, text_data):
    """Write one length-prefixed record for ``text_data`` to ``f``.

    The record is a 4-byte native-order unsigned int (payload size) followed
    by the bytes returned by ``text_data.SerializeToString()``.
    """
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    f.write(header)
    f.write(payload)
22
+
23
+
24
def pack_pb_stream(text_data):
    """Return ``text_data`` as one length-prefixed record (size header + payload)."""
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    return header + payload
27
+
28
+
29
def split_pb_stream(f):
    """Yield each raw record (size header + payload) of a length-prefixed stream.

    Records are not parsed; every yielded ``bytes`` object is the 4-byte
    native-order size header followed by exactly that many payload bytes.

    Args:
        f: Binary file-like object supporting ``read(n)``.

    Raises:
        struct.error: If the stream ends inside a 4-byte size header.
        EOFError: If the stream ends before a full payload was read
            (previously a short, corrupt record was yielded silently).
    """
    while True:
        head = f.read(4)
        if len(head) == 0:
            break

        (size,) = struct.unpack("I", head)
        payload = f.read(size)
        if len(payload) != size:
            raise EOFError(
                f"Truncated record: expected {size} bytes, got {len(payload)}"
            )

        yield head + payload
fish_speech/datasets/semantic.py CHANGED
@@ -1,496 +1,496 @@
1
- import random
2
- from dataclasses import dataclass
3
- from itertools import chain
4
- from pathlib import Path
5
- from random import Random
6
- from typing import Optional, Union
7
-
8
- import numpy as np
9
- import pyarrow.parquet as pq
10
- import torch
11
- import torch.nn.functional as F
12
- from datasets.download.streaming_download_manager import xopen
13
- from huggingface_hub import HfApi
14
- from lightning import LightningDataModule
15
- from torch.distributed import get_rank, get_world_size, is_initialized
16
- from torch.utils.data import DataLoader, IterableDataset, get_worker_info
17
- from transformers import AutoTokenizer
18
-
19
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
- from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
- from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
- from fish_speech.text.clean import clean_text
23
- from fish_speech.utils import RankedLogger
24
- from fish_speech.utils.braceexpand import braceexpand
25
-
26
- log = RankedLogger(__name__, rank_zero_only=True)
27
-
28
-
29
- def split_by_rank_worker(files):
30
- # We need to know the total number of devices
31
- # to split the data properly
32
-
33
- total_devices = 1
34
- if is_initialized():
35
- total_devices = get_world_size()
36
-
37
- worker_info = get_worker_info()
38
- if worker_info is not None:
39
- total_devices *= worker_info.num_workers
40
-
41
- if len(files) < total_devices:
42
- # Repeat the files N times to match the number of devices
43
- files = files * (total_devices // len(files) + 1)
44
-
45
- # DDP
46
- if is_initialized():
47
- files = files[get_rank() :: get_world_size()]
48
-
49
- # Split by worker
50
- if worker_info is not None:
51
- files = files[worker_info.id :: worker_info.num_workers]
52
-
53
- return files
54
-
55
-
56
- class AutoTextSemanticInstructionDataset(IterableDataset):
57
- """
58
- Auto Augment Dataset by Speaker
59
-
60
- 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
61
- 2. Automatically normalize the text
62
-
63
- For interactive mode, we use the following format (multiple sequences):
64
- <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
65
-
66
- For non-interactive mode, we use the following format (one long sequence):
67
- <s> [INST] text [/INST] ... </s>
68
- """
69
-
70
- def __init__(
71
- self,
72
- proto_files: list[str],
73
- seed: int = 42,
74
- interactive_prob: float = 0.5,
75
- max_length: int = 1024,
76
- tokenizer: AutoTokenizer = None,
77
- use_speaker: bool | float = True,
78
- causal: bool = True,
79
- num_codebooks: Optional[int] = None,
80
- skip_text_prob: float = 0.0,
81
- ):
82
- """
83
- Args:
84
- proto_files: proto buf files if using local data
85
- seed: random seed
86
- interactive_prob: probability to use interactive mode
87
- max_length: max length of the text
88
- tokenizer: tokenizer
89
- use_speaker: include speaker information in the prompt
90
- causal: use causal sampling when using local data, disable will lead to random sampling
91
- num_codebooks: number of codebooks, if None, it will be automatically detected
92
- skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
93
- """
94
-
95
- super().__init__()
96
-
97
- assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
98
-
99
- self.seed = seed
100
- self.max_length = max_length
101
- self.tokenizer = tokenizer
102
- self.interactive_prob = interactive_prob
103
- self.use_speaker = use_speaker
104
- self.proto_files = proto_files
105
- self.causal = causal
106
- self.num_codebooks = num_codebooks
107
- self.skip_text_prob = skip_text_prob
108
-
109
- self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
110
- self.groups = None
111
-
112
- def init_mock_data_server(self):
113
- if self.groups is not None:
114
- return
115
-
116
- # Expand the proto files
117
- expanded_proto_files = []
118
- for filename in self.proto_files:
119
- for i in braceexpand(filename):
120
- i = Path(i)
121
- if i.is_file():
122
- expanded_proto_files.append(i)
123
- elif i.is_dir():
124
- expanded_proto_files.extend(i.rglob("*.proto"))
125
- expanded_proto_files.extend(i.rglob("*.protos"))
126
- else:
127
- raise ValueError(f"{i} is not a file or directory")
128
-
129
- expanded_proto_files = sorted(expanded_proto_files)
130
- Random(self.seed).shuffle(expanded_proto_files)
131
-
132
- self.groups = []
133
- shard_proto_files = split_by_rank_worker(expanded_proto_files)
134
- log.info(
135
- f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
136
- )
137
-
138
- count = 0
139
- for filename in shard_proto_files:
140
- with open(filename, "rb") as f:
141
- for text_data in read_pb_stream(f):
142
- self.groups.append(text_data)
143
- count += 1
144
-
145
- log.info(f"Read total {count} groups of data")
146
-
147
- # Shuffle the lines
148
- Random(self.seed).shuffle(self.groups)
149
- self.group_weights = [len(i.sentences) for i in self.groups]
150
-
151
- def __iter__(self):
152
- while True:
153
- yield self.augment()
154
-
155
- def tokenize_sentence(self, sentence: str):
156
- sentence = clean_text(sentence)
157
- tokens = self.tokenizer.encode(
158
- f"{sentence}",
159
- max_length=10**6,
160
- add_special_tokens=False,
161
- truncation=False,
162
- )
163
- return sentence, len(tokens)
164
-
165
- def sample_data(self):
166
- if self.groups is None:
167
- self.init_mock_data_server()
168
-
169
- # Shuffle unique lines, estimate that each sample is at least 20 tokens
170
- num_samples = self.max_length // 20
171
-
172
- # choice group based on their number of samples
173
- group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
174
-
175
- if self.causal:
176
- # Sample in order
177
- if num_samples >= len(group.sentences):
178
- samples = group.sentences
179
- else:
180
- begin = random.randint(0, len(group.sentences) - num_samples)
181
- samples = group.sentences[begin : begin + num_samples]
182
- else:
183
- samples = random.choices(
184
- group.sentences, k=min(num_samples, len(group.sentences))
185
- )
186
-
187
- return SampledData(
188
- source=group.source,
189
- name=group.name,
190
- samples=samples,
191
- )
192
-
193
- def augment(self):
194
- final_text, final_semantic = [], []
195
- response = self.sample_data()
196
- if len(response.samples) == 0:
197
- # Invalid group
198
- return None
199
-
200
- samples = list(response.samples)
201
- idx = 0
202
- use_interactive = random.random() < self.interactive_prob
203
-
204
- if use_interactive is False:
205
- # Random sample based on speaker using a truncated normal distribution
206
- a = torch.tensor([0], dtype=torch.float32)
207
- torch.nn.init.trunc_normal_(
208
- a,
209
- mean=self.max_length // 2,
210
- std=self.max_length // 4,
211
- a=10,
212
- b=self.max_length,
213
- )
214
- remaining_tokens = a.long().item() - 4
215
- else:
216
- remaining_tokens = self.max_length
217
-
218
- # Use speaker
219
- if isinstance(self.use_speaker, float):
220
- use_speaker = random.random() < self.use_speaker
221
- else:
222
- use_speaker = self.use_speaker
223
-
224
- all_tokens, all_labels = [], []
225
- while remaining_tokens > 0 and len(samples) > 0:
226
- sentence = samples.pop(0)
227
-
228
- text = random.choice(sentence.texts)
229
- text, length = self.tokenize_sentence(text)
230
- remaining_tokens -= length + len(sentence.semantics[0].values)
231
-
232
- if use_interactive is False:
233
- final_text.append(text)
234
- final_semantic.append(sentence.semantics)
235
- else:
236
- # For interactive mode, we only apply speaker for the first sentence
237
- # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
238
- tokens, labels = self.pack_sentences(
239
- sentences=[text],
240
- semantics=[sentence.semantics],
241
- speaker=response.name if use_speaker else None,
242
- skip_text=random.random() < self.skip_text_prob,
243
- )
244
-
245
- all_tokens.append(tokens)
246
- all_labels.append(labels)
247
-
248
- idx += 1
249
-
250
- if use_interactive is False:
251
- tokens, labels = self.pack_sentences(
252
- final_text,
253
- semantics=final_semantic,
254
- speaker=response.name if use_speaker else None,
255
- )
256
- all_tokens.append(tokens)
257
- all_labels.append(labels)
258
-
259
- tokens = torch.cat(all_tokens, dim=1)
260
- labels = torch.cat(all_labels, dim=1)
261
-
262
- # Verify that the length is correct
263
- assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
264
-
265
- data = {"tokens": tokens, "labels": labels}
266
-
267
- return data
268
-
269
- def pack_sentences(
270
- self,
271
- sentences: list[str],
272
- semantics: list,
273
- speaker: Optional[str] = None,
274
- skip_text: bool = False,
275
- ):
276
- if speaker is None:
277
- speaker = "assistant"
278
-
279
- cated_sentences = " ".join(sentences)
280
- if skip_text:
281
- cated_sentences = "<|skip_text|>"
282
-
283
- final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
284
- final_text = final_text + f"<|im_start|>{speaker}\n"
285
-
286
- encoded = self.tokenizer.encode(
287
- final_text,
288
- add_special_tokens=False,
289
- truncation=False,
290
- max_length=10**6,
291
- )
292
- semantic_length = sum([len(i[0].values) for i in semantics])
293
- prompt_length = len(encoded)
294
- num_codebooks = (
295
- len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
296
- )
297
-
298
- # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
299
- tokens = (
300
- encoded
301
- + [self.semantic_token_id] * semantic_length
302
- + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
303
- )
304
-
305
- # Codebook bos/padding: 0, eos: 1
306
- codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
307
- for segment in semantics:
308
- for book_idx, book in zip(range(num_codebooks), segment):
309
- for j in book.values:
310
- codes[book_idx].append(int(j) + 1)
311
-
312
- for book in codes:
313
- book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
314
-
315
- tokens = [tokens] + codes
316
-
317
- tokens = torch.tensor(tokens, dtype=torch.long)
318
- labels = tokens.clone()
319
-
320
- if skip_text:
321
- # If text is not provided, the sentence is used for condition only, all labels are -100
322
- torch.fill_(labels, -100)
323
- return tokens, labels
324
-
325
- # Mask out the <s> tokens for semantic, predict semantic tokens only
326
- # Since we don't mask out the input tokens, the language modeling still works
327
- labels[1:, :prompt_length] = -100
328
-
329
- tokens = tokens[:, :-1]
330
- labels = labels[:, 1:]
331
-
332
- # Verify the padding is correct, and the last token is eos
333
- assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
334
- assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
335
-
336
- return tokens, labels
337
-
338
-
339
- @dataclass
340
- class TextDataCollator:
341
- tokenizer: AutoTokenizer
342
- max_length: int = 1024
343
-
344
- def __call__(self, examples):
345
- if "negative_tokens" in examples:
346
- positive_examples = []
347
- negative_examples = []
348
-
349
- for i in examples:
350
- positive_examples.append(
351
- {
352
- "tokens": i["tokens"],
353
- "labels": i["labels"],
354
- }
355
- )
356
- negative_examples.append(
357
- {
358
- "tokens": i["negative_tokens"],
359
- "labels": i["negative_labels"],
360
- }
361
- )
362
-
363
- examples = positive_examples + negative_examples
364
-
365
- return self.batchify(examples)
366
-
367
- def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
368
- tokens, attention_masks, labels = [], [], []
369
-
370
- # Calculate the max length
371
- max_tokens_length = 0
372
- for example in examples:
373
- max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
374
- max_tokens_length = min(max_tokens_length, self.max_length)
375
-
376
- for example in examples:
377
- _tokens = example[tokens_key][:, :max_tokens_length]
378
- _labels = example[labels_key][:, :max_tokens_length]
379
- _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
380
- tokens_length = _tokens.size(1)
381
- _attention_mask[:tokens_length] = False
382
-
383
- assert tokens_length == _labels.size(
384
- 1
385
- ), f"{tokens_length} != {_labels.size(1)}"
386
-
387
- if tokens_length < max_tokens_length:
388
- _tokens = F.pad(
389
- _tokens,
390
- (0, max_tokens_length - tokens_length),
391
- value=self.tokenizer.eos_token_id,
392
- )
393
- _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
394
- _labels = F.pad(
395
- _labels, (0, max_tokens_length - _labels.size(1)), value=-100
396
- )
397
-
398
- tokens.append(_tokens)
399
- attention_masks.append(_attention_mask)
400
- labels.append(_labels)
401
-
402
- tokens = torch.stack(tokens, dim=0)
403
- attention_masks = torch.stack(attention_masks, dim=0)
404
- labels = torch.stack(labels, dim=0)
405
-
406
- return {
407
- "inputs": tokens,
408
- "attention_masks": attention_masks,
409
- "labels": labels,
410
- }
411
-
412
-
413
- class InterleaveDataset(IterableDataset):
414
- def __init__(
415
- self,
416
- datasets: list[IterableDataset],
417
- probabilities: list[float],
418
- seed: int = 42,
419
- ):
420
- super().__init__()
421
-
422
- self.datasets = datasets
423
- self.probabilities = probabilities
424
- self.seed = seed
425
-
426
- def __iter__(self):
427
- rng = np.random.default_rng(self.seed)
428
- dataset_iterators = [iter(dataset) for dataset in self.datasets]
429
-
430
- while True:
431
- # Random choice one
432
- dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
433
- dataset_iterator = dataset_iterators[dataset_idx]
434
-
435
- try:
436
- yield next(dataset_iterator)
437
- except StopIteration:
438
- # Exhausted, create a new iterator
439
- dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
440
- yield next(dataset_iterators[dataset_idx])
441
-
442
-
443
- class SemanticDataModule(LightningDataModule):
444
- def __init__(
445
- self,
446
- train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
447
- val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
448
- batch_size: int = 32,
449
- tokenizer: AutoTokenizer = None,
450
- max_length: int = 1024,
451
- num_workers: int = 4,
452
- ):
453
- super().__init__()
454
-
455
- self.train_dataset = train_dataset
456
- self.val_dataset = val_dataset
457
- self.batch_size = batch_size
458
- self.tokenizer = tokenizer
459
- self.max_length = max_length
460
- self.num_workers = num_workers
461
-
462
- def train_dataloader(self):
463
- return DataLoader(
464
- self.train_dataset,
465
- batch_size=self.batch_size,
466
- collate_fn=TextDataCollator(self.tokenizer, self.max_length),
467
- num_workers=self.num_workers,
468
- persistent_workers=True,
469
- )
470
-
471
- def val_dataloader(self):
472
- return DataLoader(
473
- self.val_dataset,
474
- batch_size=self.batch_size,
475
- collate_fn=TextDataCollator(self.tokenizer, self.max_length),
476
- num_workers=self.num_workers,
477
- persistent_workers=True,
478
- )
479
-
480
-
481
- if __name__ == "__main__":
482
- from tqdm import tqdm
483
-
484
- ds = AutoTextSemanticInstructionDataset(
485
- ["data/protos"],
486
- tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
487
- use_speaker=False,
488
- interactive_prob=1.0,
489
- skip_text_prob=0.5,
490
- )
491
-
492
- for i in ds:
493
- print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
494
- # i["labels"][0][i["labels"][0] == -100] = 0
495
- # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
496
- break
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from random import Random
6
+ from typing import Optional, Union
7
+
8
+ import numpy as np
9
+ import pyarrow.parquet as pq
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from datasets.download.streaming_download_manager import xopen
13
+ from huggingface_hub import HfApi
14
+ from lightning import LightningDataModule
15
+ from torch.distributed import get_rank, get_world_size, is_initialized
16
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
+ from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
+ from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
+ from fish_speech.text.clean import clean_text
23
+ from fish_speech.utils import RankedLogger
24
+ from fish_speech.utils.braceexpand import braceexpand
25
+
26
+ log = RankedLogger(__name__, rank_zero_only=True)
27
+
28
+
29
+ def split_by_rank_worker(files):
30
+ # We need to know the total number of devices
31
+ # to split the data properly
32
+
33
+ total_devices = 1
34
+ if is_initialized():
35
+ total_devices = get_world_size()
36
+
37
+ worker_info = get_worker_info()
38
+ if worker_info is not None:
39
+ total_devices *= worker_info.num_workers
40
+
41
+ if len(files) < total_devices:
42
+ # Repeat the files N times to match the number of devices
43
+ files = files * (total_devices // len(files) + 1)
44
+
45
+ # DDP
46
+ if is_initialized():
47
+ files = files[get_rank() :: get_world_size()]
48
+
49
+ # Split by worker
50
+ if worker_info is not None:
51
+ files = files[worker_info.id :: worker_info.num_workers]
52
+
53
+ return files
54
+
55
+
56
+ class AutoTextSemanticInstructionDataset(IterableDataset):
57
+ """
58
+ Auto Augment Dataset by Speaker
59
+
60
+ 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
61
+ 2. Automatically normalize the text
62
+
63
+ For interactive mode, we use the following format (multiple sequences):
64
+ <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
65
+
66
+ For non-interactive mode, we use the following format (one long sequence):
67
+ <s> [INST] text [/INST] ... </s>
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ proto_files: list[str],
73
+ seed: int = 42,
74
+ interactive_prob: float = 0.5,
75
+ max_length: int = 1024,
76
+ tokenizer: AutoTokenizer = None,
77
+ use_speaker: bool | float = True,
78
+ causal: bool = True,
79
+ num_codebooks: Optional[int] = None,
80
+ skip_text_prob: float = 0.0,
81
+ ):
82
+ """
83
+ Args:
84
+ proto_files: proto buf files if using local data
85
+ seed: random seed
86
+ interactive_prob: probability to use interactive mode
87
+ max_length: max length of the text
88
+ tokenizer: tokenizer
89
+ use_speaker: include speaker information in the prompt
90
+ causal: use causal sampling when using local data, disable will lead to random sampling
91
+ num_codebooks: number of codebooks, if None, it will be automatically detected
92
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
93
+ """
94
+
95
+ super().__init__()
96
+
97
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
98
+
99
+ self.seed = seed
100
+ self.max_length = max_length
101
+ self.tokenizer = tokenizer
102
+ self.interactive_prob = interactive_prob
103
+ self.use_speaker = use_speaker
104
+ self.proto_files = proto_files
105
+ self.causal = causal
106
+ self.num_codebooks = num_codebooks
107
+ self.skip_text_prob = skip_text_prob
108
+
109
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
110
+ self.groups = None
111
+
112
+ def init_mock_data_server(self):
113
+ if self.groups is not None:
114
+ return
115
+
116
+ # Expand the proto files
117
+ expanded_proto_files = []
118
+ for filename in self.proto_files:
119
+ for i in braceexpand(filename):
120
+ i = Path(i)
121
+ if i.is_file():
122
+ expanded_proto_files.append(i)
123
+ elif i.is_dir():
124
+ expanded_proto_files.extend(i.rglob("*.proto"))
125
+ expanded_proto_files.extend(i.rglob("*.protos"))
126
+ else:
127
+ raise ValueError(f"{i} is not a file or directory")
128
+
129
+ expanded_proto_files = sorted(expanded_proto_files)
130
+ Random(self.seed).shuffle(expanded_proto_files)
131
+
132
+ self.groups = []
133
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
134
+ log.info(
135
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
136
+ )
137
+
138
+ count = 0
139
+ for filename in shard_proto_files:
140
+ with open(filename, "rb") as f:
141
+ for text_data in read_pb_stream(f):
142
+ self.groups.append(text_data)
143
+ count += 1
144
+
145
+ log.info(f"Read total {count} groups of data")
146
+
147
+ # Shuffle the lines
148
+ Random(self.seed).shuffle(self.groups)
149
+ self.group_weights = [len(i.sentences) for i in self.groups]
150
+
151
+ def __iter__(self):
152
+ while True:
153
+ yield self.augment()
154
+
155
+ def tokenize_sentence(self, sentence: str):
156
+ sentence = clean_text(sentence)
157
+ tokens = self.tokenizer.encode(
158
+ f"{sentence}",
159
+ max_length=10**6,
160
+ add_special_tokens=False,
161
+ truncation=False,
162
+ )
163
+ return sentence, len(tokens)
164
+
165
+ def sample_data(self):
166
+ if self.groups is None:
167
+ self.init_mock_data_server()
168
+
169
+ # Shuffle unique lines, estimate that each sample is at least 20 tokens
170
+ num_samples = self.max_length // 20
171
+
172
+ # choice group based on their number of samples
173
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
174
+
175
+ if self.causal:
176
+ # Sample in order
177
+ if num_samples >= len(group.sentences):
178
+ samples = group.sentences
179
+ else:
180
+ begin = random.randint(0, len(group.sentences) - num_samples)
181
+ samples = group.sentences[begin : begin + num_samples]
182
+ else:
183
+ samples = random.choices(
184
+ group.sentences, k=min(num_samples, len(group.sentences))
185
+ )
186
+
187
+ return SampledData(
188
+ source=group.source,
189
+ name=group.name,
190
+ samples=samples,
191
+ )
192
+
193
+ def augment(self):
194
+ final_text, final_semantic = [], []
195
+ response = self.sample_data()
196
+ if len(response.samples) == 0:
197
+ # Invalid group
198
+ return None
199
+
200
+ samples = list(response.samples)
201
+ idx = 0
202
+ use_interactive = random.random() < self.interactive_prob
203
+
204
+ if use_interactive is False:
205
+ # Random sample based on speaker using a truncated normal distribution
206
+ a = torch.tensor([0], dtype=torch.float32)
207
+ torch.nn.init.trunc_normal_(
208
+ a,
209
+ mean=self.max_length // 2,
210
+ std=self.max_length // 4,
211
+ a=10,
212
+ b=self.max_length,
213
+ )
214
+ remaining_tokens = a.long().item() - 4
215
+ else:
216
+ remaining_tokens = self.max_length
217
+
218
+ # Use speaker
219
+ if isinstance(self.use_speaker, float):
220
+ use_speaker = random.random() < self.use_speaker
221
+ else:
222
+ use_speaker = self.use_speaker
223
+
224
+ all_tokens, all_labels = [], []
225
+ while remaining_tokens > 0 and len(samples) > 0:
226
+ sentence = samples.pop(0)
227
+
228
+ text = random.choice(sentence.texts)
229
+ text, length = self.tokenize_sentence(text)
230
+ remaining_tokens -= length + len(sentence.semantics[0].values)
231
+
232
+ if use_interactive is False:
233
+ final_text.append(text)
234
+ final_semantic.append(sentence.semantics)
235
+ else:
236
+ # For interactive mode, we only apply speaker for the first sentence
237
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
238
+ tokens, labels = self.pack_sentences(
239
+ sentences=[text],
240
+ semantics=[sentence.semantics],
241
+ speaker=response.name if use_speaker else None,
242
+ skip_text=random.random() < self.skip_text_prob,
243
+ )
244
+
245
+ all_tokens.append(tokens)
246
+ all_labels.append(labels)
247
+
248
+ idx += 1
249
+
250
+ if use_interactive is False:
251
+ tokens, labels = self.pack_sentences(
252
+ final_text,
253
+ semantics=final_semantic,
254
+ speaker=response.name if use_speaker else None,
255
+ )
256
+ all_tokens.append(tokens)
257
+ all_labels.append(labels)
258
+
259
+ tokens = torch.cat(all_tokens, dim=1)
260
+ labels = torch.cat(all_labels, dim=1)
261
+
262
+ # Verify that the length is correct
263
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
264
+
265
+ data = {"tokens": tokens, "labels": labels}
266
+
267
+ return data
268
+
269
+ def pack_sentences(
270
+ self,
271
+ sentences: list[str],
272
+ semantics: list,
273
+ speaker: Optional[str] = None,
274
+ skip_text: bool = False,
275
+ ):
276
+ if speaker is None:
277
+ speaker = "assistant"
278
+
279
+ cated_sentences = " ".join(sentences)
280
+ if skip_text:
281
+ cated_sentences = "<|skip_text|>"
282
+
283
+ final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
284
+ final_text = final_text + f"<|im_start|>{speaker}\n"
285
+
286
+ encoded = self.tokenizer.encode(
287
+ final_text,
288
+ add_special_tokens=False,
289
+ truncation=False,
290
+ max_length=10**6,
291
+ )
292
+ semantic_length = sum([len(i[0].values) for i in semantics])
293
+ prompt_length = len(encoded)
294
+ num_codebooks = (
295
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
296
+ )
297
+
298
+ # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
299
+ tokens = (
300
+ encoded
301
+ + [self.semantic_token_id] * semantic_length
302
+ + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
303
+ )
304
+
305
+ # Codebook bos/padding: 0, eos: 1
306
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
307
+ for segment in semantics:
308
+ for book_idx, book in zip(range(num_codebooks), segment):
309
+ for j in book.values:
310
+ codes[book_idx].append(int(j) + 1)
311
+
312
+ for book in codes:
313
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
314
+
315
+ tokens = [tokens] + codes
316
+
317
+ tokens = torch.tensor(tokens, dtype=torch.long)
318
+ labels = tokens.clone()
319
+
320
+ if skip_text:
321
+ # If text is not provided, the sentence is used for condition only, all labels are -100
322
+ torch.fill_(labels, -100)
323
+ return tokens, labels
324
+
325
+ # Mask out the <s> tokens for semantic, predict semantic tokens only
326
+ # Since we don't mask out the input tokens, the language modeling still works
327
+ labels[1:, :prompt_length] = -100
328
+
329
+ tokens = tokens[:, :-1]
330
+ labels = labels[:, 1:]
331
+
332
+ # Verify the padding is correct, and the last token is eos
333
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
334
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
335
+
336
+ return tokens, labels
337
+
338
+
339
+ @dataclass
340
+ class TextDataCollator:
341
+ tokenizer: AutoTokenizer
342
+ max_length: int = 1024
343
+
344
+ def __call__(self, examples):
345
+ if "negative_tokens" in examples:
346
+ positive_examples = []
347
+ negative_examples = []
348
+
349
+ for i in examples:
350
+ positive_examples.append(
351
+ {
352
+ "tokens": i["tokens"],
353
+ "labels": i["labels"],
354
+ }
355
+ )
356
+ negative_examples.append(
357
+ {
358
+ "tokens": i["negative_tokens"],
359
+ "labels": i["negative_labels"],
360
+ }
361
+ )
362
+
363
+ examples = positive_examples + negative_examples
364
+
365
+ return self.batchify(examples)
366
+
367
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
368
+ tokens, attention_masks, labels = [], [], []
369
+
370
+ # Calculate the max length
371
+ max_tokens_length = 0
372
+ for example in examples:
373
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
374
+ max_tokens_length = min(max_tokens_length, self.max_length)
375
+
376
+ for example in examples:
377
+ _tokens = example[tokens_key][:, :max_tokens_length]
378
+ _labels = example[labels_key][:, :max_tokens_length]
379
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
380
+ tokens_length = _tokens.size(1)
381
+ _attention_mask[:tokens_length] = False
382
+
383
+ assert tokens_length == _labels.size(
384
+ 1
385
+ ), f"{tokens_length} != {_labels.size(1)}"
386
+
387
+ if tokens_length < max_tokens_length:
388
+ _tokens = F.pad(
389
+ _tokens,
390
+ (0, max_tokens_length - tokens_length),
391
+ value=self.tokenizer.eos_token_id,
392
+ )
393
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
394
+ _labels = F.pad(
395
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
396
+ )
397
+
398
+ tokens.append(_tokens)
399
+ attention_masks.append(_attention_mask)
400
+ labels.append(_labels)
401
+
402
+ tokens = torch.stack(tokens, dim=0)
403
+ attention_masks = torch.stack(attention_masks, dim=0)
404
+ labels = torch.stack(labels, dim=0)
405
+
406
+ return {
407
+ "inputs": tokens,
408
+ "attention_masks": attention_masks,
409
+ "labels": labels,
410
+ }
411
+
412
+
413
+ class InterleaveDataset(IterableDataset):
414
+ def __init__(
415
+ self,
416
+ datasets: list[IterableDataset],
417
+ probabilities: list[float],
418
+ seed: int = 42,
419
+ ):
420
+ super().__init__()
421
+
422
+ self.datasets = datasets
423
+ self.probabilities = probabilities
424
+ self.seed = seed
425
+
426
+ def __iter__(self):
427
+ rng = np.random.default_rng(self.seed)
428
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
429
+
430
+ while True:
431
+ # Random choice one
432
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
433
+ dataset_iterator = dataset_iterators[dataset_idx]
434
+
435
+ try:
436
+ yield next(dataset_iterator)
437
+ except StopIteration:
438
+ # Exhausted, create a new iterator
439
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
440
+ yield next(dataset_iterators[dataset_idx])
441
+
442
+
443
+ class SemanticDataModule(LightningDataModule):
444
+ def __init__(
445
+ self,
446
+ train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
447
+ val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
448
+ batch_size: int = 32,
449
+ tokenizer: AutoTokenizer = None,
450
+ max_length: int = 1024,
451
+ num_workers: int = 4,
452
+ ):
453
+ super().__init__()
454
+
455
+ self.train_dataset = train_dataset
456
+ self.val_dataset = val_dataset
457
+ self.batch_size = batch_size
458
+ self.tokenizer = tokenizer
459
+ self.max_length = max_length
460
+ self.num_workers = num_workers
461
+
462
+ def train_dataloader(self):
463
+ return DataLoader(
464
+ self.train_dataset,
465
+ batch_size=self.batch_size,
466
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
467
+ num_workers=self.num_workers,
468
+ persistent_workers=True,
469
+ )
470
+
471
+ def val_dataloader(self):
472
+ return DataLoader(
473
+ self.val_dataset,
474
+ batch_size=self.batch_size,
475
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
476
+ num_workers=self.num_workers,
477
+ persistent_workers=True,
478
+ )
479
+
480
+
481
+ if __name__ == "__main__":
482
+ from tqdm import tqdm
483
+
484
+ ds = AutoTextSemanticInstructionDataset(
485
+ ["data/protos"],
486
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
487
+ use_speaker=False,
488
+ interactive_prob=1.0,
489
+ skip_text_prob=0.5,
490
+ )
491
+
492
+ for i in ds:
493
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
494
+ # i["labels"][0][i["labels"][0] == -100] = 0
495
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
496
+ break
fish_speech/datasets/text.py DELETED
@@ -1,661 +0,0 @@
1
- import random
2
- from dataclasses import dataclass
3
- from itertools import chain
4
- from pathlib import Path
5
- from random import Random
6
- from typing import Optional, Union
7
-
8
- import grpc
9
- import numpy as np
10
- import pyarrow.parquet as pq
11
- import torch
12
- import torch.nn.functional as F
13
- from datasets.download.streaming_download_manager import xopen
14
- from huggingface_hub import HfApi
15
- from lightning import LightningDataModule
16
- from torch.distributed import get_rank, get_world_size, is_initialized
17
- from torch.utils.data import DataLoader, IterableDataset, get_worker_info
18
- from transformers import AutoTokenizer
19
-
20
- from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
- from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
- from fish_speech.text.clean import clean_text
23
- from fish_speech.utils import RankedLogger
24
- from fish_speech.utils.braceexpand import braceexpand
25
-
26
- log = RankedLogger(__name__, rank_zero_only=True)
27
-
28
- CODEBOOK_PAD_TOKEN_ID = 0
29
- CODEBOOK_EOS_TOKEN_ID = 1
30
-
31
-
32
- def split_by_rank_worker(files):
33
- # We need to know the total number of devices
34
- # to split the data properly
35
-
36
- total_devices = 1
37
- if is_initialized():
38
- total_devices = get_world_size()
39
-
40
- worker_info = get_worker_info()
41
- if worker_info is not None:
42
- total_devices *= worker_info.num_workers
43
-
44
- if len(files) < total_devices:
45
- # Repeat the files N times to match the number of devices
46
- files = files * (total_devices // len(files) + 1)
47
-
48
- # DDP
49
- if is_initialized():
50
- files = files[get_rank() :: get_world_size()]
51
-
52
- # Split by worker
53
- if worker_info is not None:
54
- files = files[worker_info.id :: worker_info.num_workers]
55
-
56
- return files
57
-
58
-
59
- class StreamTextDataset(IterableDataset):
60
- def __init__(
61
- self,
62
- files: Optional[Union[list[str], str]] = None,
63
- prefix: Optional[str] = None,
64
- seed: int = 42,
65
- parquet_batch_size: int = 10000,
66
- repo: str = "uonlp/CulturaX",
67
- max_length: int = 1024,
68
- tokenizer: AutoTokenizer = None,
69
- ):
70
- super().__init__()
71
-
72
- self.seed = seed
73
- self.parquet_batch_size = parquet_batch_size
74
- self.repo = repo
75
- self.max_length = max_length
76
- self.tokenizer = tokenizer
77
-
78
- if files is None and prefix is None:
79
- raise ValueError("Either files or prefix must be specified")
80
-
81
- if prefix is not None:
82
- files = HfApi().list_repo_files(repo, repo_type="dataset")
83
- files = [
84
- f for f in files if f.startswith(prefix) and f.endswith(".parquet")
85
- ]
86
- log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
87
- else:
88
- if isinstance(files, str):
89
- files = [files]
90
-
91
- files = list(chain.from_iterable(map(braceexpand, files)))
92
- log.info(f"Expanded {len(files)} files in {repo}")
93
-
94
- # Get sharded files
95
- self.files = sorted(files)
96
- Random(seed).shuffle(self.files)
97
-
98
- def __iter__(self):
99
- files = split_by_rank_worker(self.files)
100
- random.shuffle(files)
101
-
102
- for filename in files:
103
- try:
104
- yield from self.parse_data(filename)
105
- except Exception as e:
106
- log.exception(f"Failed to parse {filename}: {e}")
107
-
108
- def parse_data(self, filename: str):
109
- for data in self.parse_data_internal(filename):
110
- text = data["text"]
111
-
112
- # encode
113
- tokens = self.tokenizer.encode(
114
- text,
115
- add_special_tokens=False,
116
- truncation=False,
117
- max_length=10**6,
118
- )
119
-
120
- # Random choice self.max_length
121
- if len(tokens) > self.max_length:
122
- start = random.randint(0, len(tokens) - self.max_length)
123
- tokens = tokens[start : start + self.max_length - 1]
124
-
125
- tokens = (
126
- [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
127
- )
128
- # Pad dims
129
- placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
130
-
131
- tokens = torch.concat(
132
- [
133
- torch.tensor([tokens], dtype=torch.long),
134
- placeholder_multi_codebook,
135
- ],
136
- dim=0,
137
- )
138
- labels = tokens.clone()
139
- tokens = tokens[:, :-1]
140
- labels = labels[:, 1:]
141
- labels[1:] = -100 # remove all placeholders
142
-
143
- yield {"tokens": tokens, "labels": labels}
144
-
145
- def parse_data_internal(self, filename: str):
146
- url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
147
-
148
- with xopen(url, mode="rb") as stream:
149
- parquet_file = pq.ParquetFile(stream)
150
-
151
- for batch in parquet_file.iter_batches(
152
- batch_size=self.parquet_batch_size, columns=["text"]
153
- ):
154
- # In-batch shuffling
155
- texts = [{"text": text.as_py()} for text in batch["text"]]
156
- random.shuffle(texts)
157
- yield from texts
158
-
159
-
160
- class AutoAugTextDataset(IterableDataset):
161
- """
162
- Auto Augment Dataset by Speaker
163
-
164
- 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
165
- 2. Automatically normalize the text
166
-
167
- For interactive mode, we use the following format (multiple sequences):
168
- <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
169
-
170
- For non-interactive mode, we use the following format (one long sequence):
171
- <s> [INST] text [/INST] ... </s>
172
- """
173
-
174
- def __init__(
175
- self,
176
- proto_files: list[str],
177
- seed: int = 42,
178
- interactive_prob: float = 0.5,
179
- max_length: int = 1024,
180
- tokenizer: AutoTokenizer = None,
181
- use_speaker: bool = True,
182
- causual: bool = True,
183
- use_negative_samples: bool = False,
184
- num_codebooks: Optional[int] = None,
185
- ):
186
- """
187
- Args:
188
- proto_files: proto buf files if using local data
189
- seed: random seed
190
- interactive_prob: probability to use interactive mode
191
- max_length: max length of the text
192
- tokenizer: tokenizer
193
- use_speaker: include speaker information in the prompt
194
- causual: use causual sampling when using local data, disable will lead to random sampling
195
- use_negative_samples: generate negative samples
196
- num_codebooks: number of codebooks, if None, it will be automatically detected
197
- """
198
-
199
- super().__init__()
200
-
201
- assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
202
-
203
- self.seed = seed
204
- self.max_length = max_length
205
- self.tokenizer = tokenizer
206
- self.interactive_prob = interactive_prob
207
- self.use_speaker = use_speaker
208
- self.proto_files = proto_files
209
- self.causual = causual
210
- self.use_negative_samples = use_negative_samples
211
- self.num_codebooks = num_codebooks
212
-
213
- self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
214
- self.groups = None
215
-
216
- def init_mock_data_server(self):
217
- if self.groups is not None:
218
- return
219
-
220
- # Expand the proto files
221
- expanded_proto_files = []
222
- for filename in self.proto_files:
223
- for i in braceexpand(filename):
224
- i = Path(i)
225
- if i.is_file():
226
- expanded_proto_files.append(i)
227
- elif i.is_dir():
228
- expanded_proto_files.extend(i.rglob("*.proto"))
229
- expanded_proto_files.extend(i.rglob("*.protos"))
230
- else:
231
- raise ValueError(f"{i} is not a file or directory")
232
-
233
- expanded_proto_files = sorted(expanded_proto_files)
234
- Random(self.seed).shuffle(expanded_proto_files)
235
-
236
- self.groups = []
237
- shard_proto_files = split_by_rank_worker(expanded_proto_files)
238
- log.info(
239
- f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
240
- )
241
-
242
- count = 0
243
- for filename in shard_proto_files:
244
- with open(filename, "rb") as f:
245
- for text_data in read_pb_stream(f):
246
- self.groups.append(text_data)
247
- count += 1
248
-
249
- log.info(f"Read total {count} groups of data")
250
-
251
- # Shuffle the lines
252
- Random(self.seed).shuffle(self.groups)
253
- self.group_weights = [len(i.sentences) for i in self.groups]
254
-
255
- def __iter__(self):
256
- while True:
257
- yield self.augment()
258
-
259
- def tokenize_sentence(self, sentence: str):
260
- sentence = clean_text(sentence)
261
- tokens = self.tokenizer.encode(
262
- f"{sentence}",
263
- max_length=10**6,
264
- add_special_tokens=False,
265
- truncation=False,
266
- )
267
- return sentence, len(tokens)
268
-
269
- def sample_data(self):
270
- if self.groups is None:
271
- self.init_mock_data_server()
272
-
273
- # Shuffle unique lines, estimate that each sample is at least 20 tokens
274
- num_samples = self.max_length // 20
275
-
276
- # choice group based on their number of samples
277
- group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
278
-
279
- if self.causual:
280
- # Sample in order
281
- if num_samples >= len(group.sentences):
282
- samples = group.sentences
283
- else:
284
- begin = random.randint(0, len(group.sentences) - num_samples)
285
- samples = group.sentences[begin : begin + num_samples]
286
- else:
287
- samples = random.choices(
288
- group.sentences, k=min(num_samples, len(group.sentences))
289
- )
290
-
291
- return SampledData(
292
- source=group.source,
293
- name=group.name,
294
- samples=samples,
295
- )
296
-
297
def augment(self):
    """Build one training example from sentences sampled out of a single group.

    A token budget is drawn from a truncated normal distribution
    (mean ``max_length // 2``, std ``max_length // 4``, clipped to
    ``[10, max_length]``) and reduced by 4. Sentences are consumed in order
    until the budget is used up or the sampled sentences run out.

    Returns:
        ``None`` when the sampled group has no sentences; otherwise a dict
        with ``"tokens"`` and ``"labels"``, plus the entries produced by
        ``generate_negative_samples`` when ``self.use_negative_samples`` is set.
    """
    # Draw the token budget for this example
    budget = torch.tensor([0], dtype=torch.float32)
    torch.nn.init.trunc_normal_(
        budget,
        mean=self.max_length // 2,
        std=self.max_length // 4,
        a=10,
        b=self.max_length,
    )
    remaining_tokens = budget.long().item() - 4

    response = self.sample_data()
    if len(response.samples) == 0:
        # Invalid group
        return None

    use_interactive = random.random() < self.interactive_prob
    single_turn = use_interactive is False

    pending_text, pending_semantic = [], []
    all_tokens, all_labels = [], []

    for idx, sentence in enumerate(list(response.samples)):
        if remaining_tokens <= 0:
            break

        text, length = self.tokenize_sentence(random.choice(sentence.texts))
        remaining_tokens -= length + len(sentence.semantics[0].values)

        if single_turn:
            # Collected now, packed as one turn after the loop
            pending_text.append(text)
            pending_semantic.append(sentence.semantics)
            continue

        # For interactive mode, only the first sentence carries the speaker
        # and the BOS token:
        # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
        is_first = idx == 0
        turn_tokens, turn_labels = self.pack_sentences(
            sentences=[text],
            semantics=[sentence.semantics],
            speaker=response.name if (self.use_speaker and is_first) else None,
            add_bos=is_first,
        )
        all_tokens.append(turn_tokens)
        all_labels.append(turn_labels)

    if single_turn:
        turn_tokens, turn_labels = self.pack_sentences(
            pending_text,
            semantics=pending_semantic,
            speaker=response.name if self.use_speaker else None,
            add_bos=True,
        )
        all_tokens.append(turn_tokens)
        all_labels.append(turn_labels)

    tokens = torch.cat(all_tokens, dim=1)
    labels = torch.cat(all_labels, dim=1)

    # Tokens and labels must stay aligned column by column
    assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

    # The packed sequence must open with the BOS token
    assert tokens[0, 0] == self.tokenizer.bos_token_id

    data = {"tokens": tokens, "labels": labels}

    if self.use_negative_samples:
        data.update(self.generate_negative_samples(all_tokens, all_labels))

    return data
371
-
372
def generate_negative_samples(self, all_tokens, all_labels):
    """Corrupt each packed turn to produce a negative example.

    For every ``(tokens, labels)`` pair a random span is picked inside the
    region where the codebook rows of ``labels`` are not masked with ``-100``.
    The span is then either duplicated ("repeat"), removed ("lost"), or
    shuffled column-wise with independent permutations for tokens and labels
    ("noise").

    Returns:
        Dict with ``"negative_tokens"`` and ``"negative_labels"``, each the
        concatenation (along dim 1) of the corrupted turns.
    """
    corrupted_tokens, corrupted_labels = [], []

    for tokens, labels in zip(all_tokens, all_labels):
        # First column where the codebook rows are not all -100
        fully_masked = -100 * (labels.size(0) - 1)
        start = torch.where(labels[1:].sum(0) != fully_masked)[0][0]
        # From that column on, no codebook label may be masked
        assert (labels[1:, start:] != -100).all()

        mode = random.choice(["repeat", "lost", "noise"])
        last_col = labels.size(1) - 1
        begin = random.randint(start, last_col)
        end = random.randint(begin, last_col)

        token_head, token_span, token_tail = (
            tokens[:, :begin],
            tokens[:, begin:end],
            tokens[:, end:],
        )
        label_head, label_span, label_tail = (
            labels[:, :begin],
            labels[:, begin:end],
            labels[:, end:],
        )

        if mode == "repeat":
            token_parts = [token_head, token_span, token_span, token_tail]
            label_parts = [label_head, label_span, label_span, label_tail]
        elif mode == "lost":
            token_parts = [token_head, token_tail]
            label_parts = [label_head, label_tail]
        elif mode == "noise":
            # Tokens and labels are permuted independently of each other
            token_order = torch.randperm(token_span.size(1))
            label_order = torch.randperm(token_span.size(1))
            token_parts = [token_head, token_span[:, token_order], token_tail]
            label_parts = [label_head, label_span[:, label_order], label_tail]
        else:
            token_parts, label_parts = [tokens], [labels]

        corrupted_tokens.append(torch.cat(token_parts, dim=1))
        corrupted_labels.append(torch.cat(label_parts, dim=1))

    tokens = torch.cat(corrupted_tokens, dim=1)
    labels = torch.cat(corrupted_labels, dim=1)

    # Tokens and labels must stay aligned column by column
    assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

    return {"negative_tokens": tokens, "negative_labels": labels}
432
-
433
def pack_sentences(
    self,
    sentences: list[str],
    semantics: list,
    speaker: Optional[str] = None,
    add_bos: bool = True,
):
    """Pack text sentences and their semantic codes into a token/label pair.

    Args:
        sentences: Text pieces; joined with single spaces into the user turn.
        semantics: One entry per sentence. Each entry is a sequence of
            codebooks whose ``.values`` hold the semantic codes.
        speaker: When given, a ``[SPK: <speaker>]`` tag is prepended to the
            sentences.
        add_bos: Whether to prepend the tokenizer's BOS token.

    Returns:
        ``(tokens, labels)`` as long tensors with ``1 + num_codebooks`` rows.
        ``labels`` is the sequence shifted left by one position, with the
        prompt part of the codebook rows set to ``-100``.
    """
    if speaker is not None:
        sentences = [f"[SPK: {speaker}]"] + sentences

    final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
    final_text = final_text + "<|im_start|>assistant<|im_sep|>"

    encoded = self.tokenizer.encode(
        final_text,
        add_special_tokens=False,
        truncation=False,
        max_length=10**6,
    )
    # Length is measured on the first codebook of every segment
    semantic_length = sum(len(segment[0].values) for segment in semantics)
    prompt_length = len(encoded)
    num_codebooks = (
        len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
    )

    bos_bias = 1 if add_bos else 0

    # Text row: prompt, one placeholder per semantic frame, then the two
    # closing special tokens
    tokens = (
        encoded
        + [self.semantic_token_id] * semantic_length
        + self.tokenizer.convert_tokens_to_ids(
            ["<|im_end|>", "<|end_of_sequence|>"]
        )
    )

    if add_bos:
        tokens = [self.tokenizer.bos_token_id] + tokens

    # Codebook rows: padding under the prompt (and BOS), then the codes
    codes = [
        [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
        for _ in range(num_codebooks)
    ]
    for segment in semantics:
        for book_idx, book in zip(range(num_codebooks), segment):
            for j in book.values:
                # NOTE(review): the +2 shift presumably reserves the ids used
                # by CODEBOOK_PAD_TOKEN_ID / CODEBOOK_EOS_TOKEN_ID — confirm
                codes[book_idx].append(int(j) + 2)

    # Two EOS entries line up with the two closing tokens of the text row
    for book in codes:
        book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)

    tokens = [tokens] + codes

    tokens = torch.tensor(tokens, dtype=torch.long)
    labels = tokens.clone()

    # Only the codebook rows are masked under the prompt; the text row keeps
    # its labels there
    labels[1:, : (prompt_length + bos_bias)] = -100

    # Shift by one so that labels[:, t] is the target for tokens[:, t]
    tokens = tokens[:, :-1]
    labels = labels[:, 1:]

    # Verify the padding is correct, and the last token is eos
    assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
    assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
    assert labels[0, -1] == self.tokenizer.eos_token_id
    assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()

    return tokens, labels
504
-
505
-
506
@dataclass
class TextDataCollator:
    """Collate packed examples into padded batch tensors.

    Each example is a dict holding ``"tokens"`` and ``"labels"`` tensors and,
    optionally, ``"negative_tokens"`` / ``"negative_labels"``.
    """

    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        """Batch ``examples``, appending negative samples after the positives.

        The check is made per example: ``examples`` is a list of dicts, so
        testing the key against the list itself would never match.
        """
        negative_examples = [
            {"tokens": ex["negative_tokens"], "labels": ex["negative_labels"]}
            for ex in examples
            if "negative_tokens" in ex
        ]

        if negative_examples:
            positive_examples = [
                {"tokens": ex["tokens"], "labels": ex["labels"]} for ex in examples
            ]
            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        """Truncate/pad every example to a common length and stack them.

        Returns:
            Dict with ``"inputs"``, ``"attention_masks"`` (``True`` on padded
            positions) and ``"labels"`` (padded with ``-100``).
        """
        tokens, attention_masks, labels = [], [], []

        # Longest example in the batch, capped at max_length
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                # Pad every row with EOS first, then overwrite the codebook
                # rows with the codebook padding id
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
578
-
579
-
580
class InterleaveDataset(IterableDataset):
    """Endlessly yield items from several datasets, picked at random.

    On every step one dataset is chosen according to ``probabilities`` using
    a generator seeded with ``seed``. A dataset that runs out is restarted
    from its beginning.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        iterators = list(map(iter, self.datasets))
        dataset_count = len(self.datasets)

        while True:
            # Pick which dataset supplies the next item
            chosen = rng.choice(dataset_count, p=self.probabilities)

            try:
                item = next(iterators[chosen])
            except StopIteration:
                # Exhausted: restart this dataset and take its first item
                iterators[chosen] = iter(self.datasets[chosen])
                item = next(iterators[chosen])

            yield item
608
-
609
-
610
class TextDataModule(LightningDataModule):
    """Lightning data module wrapping a training and a validation dataset.

    Both loaders share the batch size, the worker count and a
    ``TextDataCollator`` built from ``tokenizer`` and ``max_length``.
    """

    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset, self.val_dataset = train_dataset, val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def _make_loader(self, dataset):
        # Single place that wires dataset, collator and worker settings
        collator = TextDataCollator(self.tokenizer, self.max_length)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=collator,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)
644
-
645
-
646
if __name__ == "__main__":
    from tqdm import tqdm

    # Manual smoke test: build the dataset and show the first decoded sample
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    ds = AutoAugTextDataset(
        ["data/protos"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        use_negative_samples=False,
    )

    for sample in ds:
        text_row = sample["tokens"][0]
        print(ds.tokenizer.decode(text_row, skip_special_tokens=False))
        # To inspect the labels as well, replace the -100 entries of
        # sample["labels"][0] with 0 and decode that row the same way.
        break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/datasets/vqgan.py CHANGED
@@ -1,147 +1,147 @@
1
- from dataclasses import dataclass
2
- from pathlib import Path
3
- from typing import Optional
4
-
5
- import librosa
6
- import numpy as np
7
- import torch
8
- from lightning import LightningDataModule
9
- from torch.utils.data import DataLoader, Dataset
10
-
11
- from fish_speech.utils import RankedLogger
12
-
13
- logger = RankedLogger(__name__, rank_zero_only=False)
14
-
15
-
16
class VQGANDataset(Dataset):
    """Audio dataset driven by a text file listing one audio path per line.

    Paths in the list are resolved relative to the directory containing the
    list file; blank lines are ignored.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        list_path = Path(filelist)
        base_dir = list_path.parent
        entries = (
            raw.strip()
            for raw in list_path.read_text(encoding="utf-8").splitlines()
        )

        self.files = [base_dir / entry for entry in entries if entry]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        """Load one file as mono audio; return ``None`` for empty audio."""
        path = self.files[idx]

        audio, _ = librosa.load(path, sr=self.sample_rate, mono=True)

        # Crop a random window of slice_frames * hop_length samples when the
        # clip is longer than that
        if self.slice_frames is not None:
            window = self.slice_frames * self.hop_length
            if audio.shape[0] > window:
                offset = np.random.randint(0, audio.shape[0] - window)
                audio = audio[offset : offset + window]

        if len(audio) == 0:
            return None

        # Scale down only when the peak exceeds 1.0
        peak = np.abs(audio).max()
        if peak > 1.0:
            audio = audio / peak

        return {"audio": torch.from_numpy(audio)}

    def __getitem__(self, idx):
        # Best effort: a failing file is reported and yields None
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
76
-
77
-
78
@dataclass
class VQGANCollator:
    """Collate audio items into one zero-padded batch.

    ``None`` items are dropped. Every remaining clip is right-padded with
    zeros to the length of the longest clip in the batch.
    """

    def __call__(self, batch):
        items = [item for item in batch if item is not None]

        audio_lengths = torch.tensor([len(item["audio"]) for item in items])
        longest = audio_lengths.max()

        # Right-pad each clip up to the longest one
        padded = [
            torch.nn.functional.pad(
                item["audio"], (0, longest - len(item["audio"]))
            )
            for item in items
        ]

        return {
            "audios": torch.stack(padded),
            "audio_lengths": audio_lengths,
        }
97
-
98
-
99
class VQGANDataModule(LightningDataModule):
    """Lightning data module providing train/validation loaders for VQGAN.

    Args:
        train_dataset: Dataset used by ``train_dataloader`` (shuffled).
        val_dataset: Dataset used by ``val_dataloader``.
        batch_size: Training batch size.
        num_workers: Worker processes for both loaders.
        val_batch_size: Validation batch size; falls back to ``batch_size``
            when ``None`` or 0.
    """

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0,
        )
134
-
135
-
136
if __name__ == "__main__":
    # Manual smoke test: collate one batch and show what the collator returns
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # VQGANCollator only produces "audios" and "audio_lengths"; printing
        # "features" / "feature_lengths" would raise KeyError
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ from lightning import LightningDataModule
9
+ from torch.utils.data import DataLoader, Dataset
10
+
11
+ from fish_speech.utils import RankedLogger
12
+
13
+ logger = RankedLogger(__name__, rank_zero_only=False)
14
+
15
+
16
class VQGANDataset(Dataset):
    """Audio dataset driven by a text file listing one audio path per line.

    Paths in the list are resolved relative to the directory containing the
    list file; blank lines are ignored.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        list_path = Path(filelist)
        base_dir = list_path.parent
        entries = (
            raw.strip()
            for raw in list_path.read_text(encoding="utf-8").splitlines()
        )

        self.files = [base_dir / entry for entry in entries if entry]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        """Load one file as mono audio; return ``None`` for empty audio."""
        path = self.files[idx]

        audio, _ = librosa.load(path, sr=self.sample_rate, mono=True)

        # Crop a random window of slice_frames * hop_length samples when the
        # clip is longer than that
        if self.slice_frames is not None:
            window = self.slice_frames * self.hop_length
            if audio.shape[0] > window:
                offset = np.random.randint(0, audio.shape[0] - window)
                audio = audio[offset : offset + window]

        if len(audio) == 0:
            return None

        # Scale down only when the peak exceeds 1.0
        peak = np.abs(audio).max()
        if peak > 1.0:
            audio = audio / peak

        return {"audio": torch.from_numpy(audio)}

    def __getitem__(self, idx):
        # Best effort: a failing file is reported and yields None
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
76
+
77
+
78
@dataclass
class VQGANCollator:
    """Collate audio items into one zero-padded batch.

    ``None`` items are dropped. Every remaining clip is right-padded with
    zeros to the length of the longest clip in the batch.
    """

    def __call__(self, batch):
        items = [item for item in batch if item is not None]

        audio_lengths = torch.tensor([len(item["audio"]) for item in items])
        longest = audio_lengths.max()

        # Right-pad each clip up to the longest one
        padded = [
            torch.nn.functional.pad(
                item["audio"], (0, longest - len(item["audio"]))
            )
            for item in items
        ]

        return {
            "audios": torch.stack(padded),
            "audio_lengths": audio_lengths,
        }
97
+
98
+
99
class VQGANDataModule(LightningDataModule):
    """Lightning data module providing train/validation loaders for VQGAN.

    Args:
        train_dataset: Dataset used by ``train_dataloader`` (shuffled).
        val_dataset: Dataset used by ``val_dataloader``.
        batch_size: Training batch size.
        num_workers: Worker processes for both loaders.
        val_batch_size: Validation batch size; falls back to ``batch_size``
            when ``None`` or 0.
    """

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0,
        )
134
+
135
+
136
if __name__ == "__main__":
    # Manual smoke test: collate one batch and show what the collator returns
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # VQGANCollator only produces "audios" and "audio_lengths"; printing
        # "features" / "feature_lengths" would raise KeyError
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
fish_speech/i18n/README.md CHANGED
@@ -1,27 +1,27 @@
1
- ## i18n Folder Attribution
2
-
3
- The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4
-
5
- ### fish_speech/i18n/core.py
6
-
7
- **Related code from RVC:**
8
- [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9
-
10
- **Initial commit:**
11
- add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12
-
13
- **Initial author:**
14
- [@L4Ph](https://github.com/L4Ph)
15
-
16
- ### fish_speech/i18n/scan.py
17
-
18
- **Related code from RVC:**
19
- [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20
-
21
- **Initial commit:**
22
- File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23
-
24
- **Initial author:**
25
- [@towzeur](https://github.com/towzeur)
26
-
27
- We appreciate the contributions of the RVC project and its authors.
 
1
+ ## i18n Folder Attribution
2
+
3
+ The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4
+
5
+ ### fish_speech/i18n/core.py
6
+
7
+ **Related code from RVC:**
8
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9
+
10
+ **Initial commit:**
11
+ add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12
+
13
+ **Initial author:**
14
+ [@L4Ph](https://github.com/L4Ph)
15
+
16
+ ### fish_speech/i18n/scan.py
17
+
18
+ **Related code from RVC:**
19
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20
+
21
+ **Initial commit:**
22
+ File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23
+
24
+ **Initial author:**
25
+ [@towzeur](https://github.com/towzeur)
26
+
27
+ We appreciate the contributions of the RVC project and its authors.
fish_speech/i18n/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from .core import i18n
2
-
3
- __all__ = ["i18n"]
 
1
+ from .core import i18n
2
+
3
+ __all__ = ["i18n"]
fish_speech/i18n/core.py CHANGED
@@ -1,40 +1,40 @@
1
- import json
2
- import locale
3
- from pathlib import Path
4
-
5
- I18N_FILE_PATH = Path(__file__).parent / "locale"
6
- DEFAULT_LANGUAGE = "en_US"
7
-
8
-
9
def load_language_list(language):
    """Return the parsed ``<language>.json`` file from the locale folder."""
    locale_file = I18N_FILE_PATH / f"{language}.json"
    with open(locale_file, "r", encoding="utf-8") as handle:
        return json.load(handle)
14
-
15
-
16
class I18nAuto:
    """Callable translator that maps UI keys to the active language.

    The language is read from a ``.locale`` file in the current working
    directory when one exists, otherwise from the system default locale.
    If no matching JSON file exists in the locale folder, the default
    language is used.
    """

    def __init__(self):
        override = Path(".locale")

        if override.exists():
            language = override.read_text(encoding="utf-8").strip()
        else:
            # getlocale can't identify the system's language ((None, None))
            language = locale.getdefaultlocale()[0]

        if not (I18N_FILE_PATH / f"{language}.json").exists():
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    def __call__(self, key):
        # Unknown keys are returned unchanged
        translations = self.language_map
        return translations.get(key, key)

    def __repr__(self):
        prefix = "Use Language: "
        return prefix + self.language
38
-
39
-
40
- i18n = I18nAuto()
 
1
+ import json
2
+ import locale
3
+ from pathlib import Path
4
+
5
+ I18N_FILE_PATH = Path(__file__).parent / "locale"
6
+ DEFAULT_LANGUAGE = "en_US"
7
+
8
+
9
def load_language_list(language):
    """Return the parsed ``<language>.json`` file from the locale folder."""
    locale_file = I18N_FILE_PATH / f"{language}.json"
    with open(locale_file, "r", encoding="utf-8") as handle:
        return json.load(handle)
14
+
15
+
16
class I18nAuto:
    """Callable translator that maps UI keys to the active language.

    The language is read from a ``.locale`` file in the current working
    directory when one exists, otherwise from the system default locale.
    If no matching JSON file exists in the locale folder, the default
    language is used.
    """

    def __init__(self):
        override = Path(".locale")

        if override.exists():
            language = override.read_text(encoding="utf-8").strip()
        else:
            # getlocale can't identify the system's language ((None, None))
            language = locale.getdefaultlocale()[0]

        if not (I18N_FILE_PATH / f"{language}.json").exists():
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    def __call__(self, key):
        # Unknown keys are returned unchanged
        translations = self.language_map
        return translations.get(key, key)

    def __repr__(self):
        prefix = "Use Language: "
        return prefix + self.language
38
+
39
+
40
+ i18n = I18nAuto()
fish_speech/i18n/locale/en_US.json CHANGED
@@ -1,122 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
- "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
- "Add to Processing Area": "Add to Processing Area",
7
- "Added path successfully!": "Added path successfully!",
8
- "Advanced Config": "Advanced Config",
9
- "Base LLAMA Model": "Base LLAMA Model",
10
- "Batch Inference": "Batch Inference",
11
- "Batch Size": "Batch Size",
12
- "Changing with the Model Path": "Changing with the Model Path",
13
- "Chinese": "Chinese",
14
- "Compile Model": "Compile Model",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
- "Copy": "Copy",
17
- "Data Preprocessing": "Data Preprocessing",
18
- "Data Preprocessing Path": "Data Preprocessing Path",
19
- "Data Source": "Data Source",
20
- "Decoder Model Config": "Decoder Model Config",
21
- "Decoder Model Path": "Decoder Model Path",
22
- "Disabled": "Disabled",
23
- "Enable Reference Audio": "Enable Reference Audio",
24
- "English": "English",
25
- "Error Message": "Error Message",
26
- "File Preprocessing": "File Preprocessing",
27
- "Generate": "Generate",
28
- "Generated Audio": "Generated Audio",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
- "Infer interface is closed": "Infer interface is closed",
31
- "Inference Configuration": "Inference Configuration",
32
- "Inference Server Configuration": "Inference Server Configuration",
33
- "Inference Server Error": "Inference Server Error",
34
- "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
- "Initial Learning Rate": "Initial Learning Rate",
36
- "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
- "Input Text": "Input Text",
38
- "Invalid path: {}": "Invalid path: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
- "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
- "Japanese": "Japanese",
42
- "LLAMA Configuration": "LLAMA Configuration",
43
- "LLAMA Model Config": "LLAMA Model Config",
44
- "LLAMA Model Path": "LLAMA Model Path",
45
- "Labeling Device": "Labeling Device",
46
- "LoRA Model to be merged": "LoRA Model to be merged",
47
- "Maximum Audio Duration": "Maximum Audio Duration",
48
- "Maximum Length per Sample": "Maximum Length per Sample",
49
- "Maximum Training Steps": "Maximum Training Steps",
50
- "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
- "Merge": "Merge",
52
- "Merge LoRA": "Merge LoRA",
53
- "Merge successfully": "Merge successfully",
54
- "Minimum Audio Duration": "Minimum Audio Duration",
55
- "Model Output Path": "Model Output Path",
56
- "Model Size": "Model Size",
57
- "Move": "Move",
58
- "Move files successfully": "Move files successfully",
59
- "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
- "No selected options": "No selected options",
61
- "Number of Workers": "Number of Workers",
62
- "Open Inference Server": "Open Inference Server",
63
- "Open Labeler WebUI": "Open Labeler WebUI",
64
- "Open Tensorboard": "Open Tensorboard",
65
- "Opened labeler in browser": "Opened labeler in browser",
66
- "Optional Label Language": "Optional Label Language",
67
- "Optional online ver": "Optional online ver",
68
- "Output Path": "Output Path",
69
- "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
- "Precision": "Precision",
71
- "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
- "Put your text here.": "Put your text here.",
73
- "Reference Audio": "Reference Audio",
74
- "Reference Text": "Reference Text",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
- "Remove Selected Data": "Remove Selected Data",
77
- "Removed path successfully!": "Removed path successfully!",
78
- "Repetition Penalty": "Repetition Penalty",
79
- "Save model every n steps": "Save model every n steps",
80
- "Select LLAMA ckpt": "Select LLAMA ckpt",
81
- "Select VITS ckpt": "Select VITS ckpt",
82
- "Select VQGAN ckpt": "Select VQGAN ckpt",
83
- "Select source file processing method": "Select source file processing method",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
- "Selected: {}": "Selected: {}",
86
- "Speaker": "Speaker",
87
- "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
- "Start Training": "Start Training",
89
- "Streaming Audio": "Streaming Audio",
90
- "Streaming Generate": "Streaming Generate",
91
- "Tensorboard Host": "Tensorboard Host",
92
- "Tensorboard Log Path": "Tensorboard Log Path",
93
- "Tensorboard Port": "Tensorboard Port",
94
- "Tensorboard interface is closed": "Tensorboard interface is closed",
95
- "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
- "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
- "Training Configuration": "Training Configuration",
99
- "Training Error": "Training Error",
100
- "Training stopped": "Training stopped",
101
- "Type name of the speaker": "Type name of the speaker",
102
- "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
- "Use LoRA": "Use LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
- "Use filelist": "Use filelist",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
- "VITS Configuration": "VITS Configuration",
108
- "VQGAN Configuration": "VQGAN Configuration",
109
- "Validation Batch Size": "Validation Batch Size",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
- "WebUI Host": "WebUI Host",
113
- "WebUI Port": "WebUI Port",
114
- "Whisper Model": "Whisper Model",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
- "latest": "latest",
118
- "new": "new",
119
- "Realtime Transform Text": "Realtime Transform Text",
120
- "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
- "Text Normalization": "Text Normalization"
122
- }
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
+ "Add to Processing Area": "Add to Processing Area",
7
+ "Added path successfully!": "Added path successfully!",
8
+ "Advanced Config": "Advanced Config",
9
+ "Base LLAMA Model": "Base LLAMA Model",
10
+ "Batch Inference": "Batch Inference",
11
+ "Batch Size": "Batch Size",
12
+ "Changing with the Model Path": "Changing with the Model Path",
13
+ "Chinese": "Chinese",
14
+ "Compile Model": "Compile Model",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
+ "Copy": "Copy",
17
+ "Data Preprocessing": "Data Preprocessing",
18
+ "Data Preprocessing Path": "Data Preprocessing Path",
19
+ "Data Source": "Data Source",
20
+ "Decoder Model Config": "Decoder Model Config",
21
+ "Decoder Model Path": "Decoder Model Path",
22
+ "Disabled": "Disabled",
23
+ "Enable Reference Audio": "Enable Reference Audio",
24
+ "English": "English",
25
+ "Error Message": "Error Message",
26
+ "File Preprocessing": "File Preprocessing",
27
+ "Generate": "Generate",
28
+ "Generated Audio": "Generated Audio",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
+ "Infer interface is closed": "Infer interface is closed",
31
+ "Inference Configuration": "Inference Configuration",
32
+ "Inference Server Configuration": "Inference Server Configuration",
33
+ "Inference Server Error": "Inference Server Error",
34
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
+ "Initial Learning Rate": "Initial Learning Rate",
36
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
+ "Input Text": "Input Text",
38
+ "Invalid path: {}": "Invalid path: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
+ "Japanese": "Japanese",
42
+ "LLAMA Configuration": "LLAMA Configuration",
43
+ "LLAMA Model Config": "LLAMA Model Config",
44
+ "LLAMA Model Path": "LLAMA Model Path",
45
+ "Labeling Device": "Labeling Device",
46
+ "LoRA Model to be merged": "LoRA Model to be merged",
47
+ "Maximum Audio Duration": "Maximum Audio Duration",
48
+ "Maximum Length per Sample": "Maximum Length per Sample",
49
+ "Maximum Training Steps": "Maximum Training Steps",
50
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
+ "Merge": "Merge",
52
+ "Merge LoRA": "Merge LoRA",
53
+ "Merge successfully": "Merge successfully",
54
+ "Minimum Audio Duration": "Minimum Audio Duration",
55
+ "Model Output Path": "Model Output Path",
56
+ "Model Size": "Model Size",
57
+ "Move": "Move",
58
+ "Move files successfully": "Move files successfully",
59
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
+ "No selected options": "No selected options",
61
+ "Number of Workers": "Number of Workers",
62
+ "Open Inference Server": "Open Inference Server",
63
+ "Open Labeler WebUI": "Open Labeler WebUI",
64
+ "Open Tensorboard": "Open Tensorboard",
65
+ "Opened labeler in browser": "Opened labeler in browser",
66
+ "Optional Label Language": "Optional Label Language",
67
+ "Optional online ver": "Optional online ver",
68
+ "Output Path": "Output Path",
69
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
+ "Precision": "Precision",
71
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
+ "Put your text here.": "Put your text here.",
73
+ "Reference Audio": "Reference Audio",
74
+ "Reference Text": "Reference Text",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
+ "Remove Selected Data": "Remove Selected Data",
77
+ "Removed path successfully!": "Removed path successfully!",
78
+ "Repetition Penalty": "Repetition Penalty",
79
+ "Save model every n steps": "Save model every n steps",
80
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
81
+ "Select VITS ckpt": "Select VITS ckpt",
82
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
83
+ "Select source file processing method": "Select source file processing method",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
+ "Selected: {}": "Selected: {}",
86
+ "Speaker": "Speaker",
87
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
+ "Start Training": "Start Training",
89
+ "Streaming Audio": "Streaming Audio",
90
+ "Streaming Generate": "Streaming Generate",
91
+ "Tensorboard Host": "Tensorboard Host",
92
+ "Tensorboard Log Path": "Tensorboard Log Path",
93
+ "Tensorboard Port": "Tensorboard Port",
94
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
95
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
+ "Training Configuration": "Training Configuration",
99
+ "Training Error": "Training Error",
100
+ "Training stopped": "Training stopped",
101
+ "Type name of the speaker": "Type name of the speaker",
102
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
+ "Use LoRA": "Use LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
+ "Use filelist": "Use filelist",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
+ "VITS Configuration": "VITS Configuration",
108
+ "VQGAN Configuration": "VQGAN Configuration",
109
+ "Validation Batch Size": "Validation Batch Size",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
+ "WebUI Host": "WebUI Host",
113
+ "WebUI Port": "WebUI Port",
114
+ "Whisper Model": "Whisper Model",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
+ "latest": "latest",
118
+ "new": "new",
119
+ "Realtime Transform Text": "Realtime Transform Text",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
+ "Text Normalization": "Text Normalization",
122
+ "Select Example Audio": "Select Example Audio"
123
+ }
fish_speech/i18n/locale/es_ES.json CHANGED
@@ -1,122 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
- "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
- "Add to Processing Area": "Agregar al Área de Procesamiento",
7
- "Added path successfully!": "¡Ruta agregada exitosamente!",
8
- "Advanced Config": "Configuración Avanzada",
9
- "Base LLAMA Model": "Modelo Base LLAMA",
10
- "Batch Inference": "Inferencia por Lote",
11
- "Batch Size": "Tamaño del Lote",
12
- "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
- "Chinese": "Chino",
14
- "Compile Model": "Compilar Modelo",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
- "Copy": "Copiar",
17
- "Data Preprocessing": "Preprocesamiento de Datos",
18
- "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
- "Data Source": "Fuente de Datos",
20
- "Decoder Model Config": "Configuración del modelo decodificador",
21
- "Decoder Model Path": "Ruta del modelo decodificador",
22
- "Disabled": "Desactivado",
23
- "Enable Reference Audio": "Habilitar Audio de Referencia",
24
- "English": "Inglés",
25
- "Error Message": "Mensaje de Error",
26
- "File Preprocessing": "Preprocesamiento de Archivos",
27
- "Generate": "Generar",
28
- "Generated Audio": "Audio Generado",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
- "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
- "Inference Configuration": "Configuración de Inferencia",
32
- "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
- "Inference Server Error": "Error del Servidor de Inferencia",
34
- "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
- "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
- "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
- "Input Text": "Texto de Entrada",
38
- "Invalid path: {}": "Ruta inválida: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
- "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
- "Japanese": "Japonés",
42
- "LLAMA Configuration": "Configuración de LLAMA",
43
- "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
- "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
- "Labeling Device": "Dispositivo de Etiquetado",
46
- "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
- "Maximum Audio Duration": "Duración máxima de audio",
48
- "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
- "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
- "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
- "Merge": "Fusionar",
52
- "Merge LoRA": "Fusionar LoRA",
53
- "Merge successfully": "Fusionado exitosamente",
54
- "Minimum Audio Duration": "Duración mínima de audio",
55
- "Model Output Path": "Ruta de Salida del Modelo",
56
- "Model Size": "Tamaño del Modelo",
57
- "Move": "Mover",
58
- "Move files successfully": "Archivos movidos exitosamente",
59
- "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
- "No selected options": "No hay opciones seleccionadas",
61
- "Number of Workers": "Número de Trabajadores",
62
- "Open Inference Server": "Abrir Servidor de Inferencia",
63
- "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
- "Open Tensorboard": "Abrir Tensorboard",
65
- "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
- "Optional Label Language": "Idioma de Etiquetado Opcional",
67
- "Optional online ver": "Ver en línea opcional",
68
- "Output Path": "Ruta de Salida",
69
- "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
70
- "Precision": "Precisión",
71
- "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
- "Put your text here.": "Ponga su texto aquí.",
73
- "Reference Audio": "Audio de Referencia",
74
- "Reference Text": "Texto de Referencia",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
- "Remove Selected Data": "Eliminar Datos Seleccionados",
77
- "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
- "Repetition Penalty": "Penalización por Repetición",
79
- "Save model every n steps": "Guardar modelo cada n pasos",
80
- "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
- "Select VITS ckpt": "Seleccionar punto de control VITS",
82
- "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
- "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
- "Selected: {}": "Seleccionado: {}",
86
- "Speaker": "Hablante",
87
- "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
- "Start Training": "Iniciar Entrenamiento",
89
- "Streaming Audio": "transmisión de audio",
90
- "Streaming Generate": "síntesis en flujo",
91
- "Tensorboard Host": "Host de Tensorboard",
92
- "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
- "Tensorboard Port": "Puerto de Tensorboard",
94
- "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
- "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
- "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
- "Training Configuration": "Configuración de Entrenamiento",
99
- "Training Error": "Error de Entrenamiento",
100
- "Training stopped": "Entrenamiento detenido",
101
- "Type name of the speaker": "Escriba el nombre del hablante",
102
- "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
- "Use LoRA": "Usar LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
- "Use filelist": "Usar lista de archivos",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
- "VITS Configuration": "Configuración de VITS",
108
- "VQGAN Configuration": "Configuración de VQGAN",
109
- "Validation Batch Size": "Tamaño del Lote de Validación",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
- "WebUI Host": "Host de WebUI",
113
- "WebUI Port": "Puerto de WebUI",
114
- "Whisper Model": "Modelo Whisper",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
- "latest": "más reciente",
118
- "new": "nuevo",
119
- "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
- "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
- "Text Normalization": "Normalización de Texto"
122
- }
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
7
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
8
+ "Advanced Config": "Configuración Avanzada",
9
+ "Base LLAMA Model": "Modelo Base LLAMA",
10
+ "Batch Inference": "Inferencia por Lote",
11
+ "Batch Size": "Tamaño del Lote",
12
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
+ "Chinese": "Chino",
14
+ "Compile Model": "Compilar Modelo",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
+ "Copy": "Copiar",
17
+ "Data Preprocessing": "Preprocesamiento de Datos",
18
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
+ "Data Source": "Fuente de Datos",
20
+ "Decoder Model Config": "Configuración del modelo decodificador",
21
+ "Decoder Model Path": "Ruta del modelo decodificador",
22
+ "Disabled": "Desactivado",
23
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
24
+ "English": "Inglés",
25
+ "Error Message": "Mensaje de Error",
26
+ "File Preprocessing": "Preprocesamiento de Archivos",
27
+ "Generate": "Generar",
28
+ "Generated Audio": "Audio Generado",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
+ "Inference Configuration": "Configuración de Inferencia",
32
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
+ "Inference Server Error": "Error del Servidor de Inferencia",
34
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
+ "Input Text": "Texto de Entrada",
38
+ "Invalid path: {}": "Ruta inválida: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
+ "Japanese": "Japonés",
42
+ "LLAMA Configuration": "Configuración de LLAMA",
43
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
+ "Labeling Device": "Dispositivo de Etiquetado",
46
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
+ "Maximum Audio Duration": "Duración máxima de audio",
48
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
+ "Merge": "Fusionar",
52
+ "Merge LoRA": "Fusionar LoRA",
53
+ "Merge successfully": "Fusionado exitosamente",
54
+ "Minimum Audio Duration": "Duración mínima de audio",
55
+ "Model Output Path": "Ruta de Salida del Modelo",
56
+ "Model Size": "Tamaño del Modelo",
57
+ "Move": "Mover",
58
+ "Move files successfully": "Archivos movidos exitosamente",
59
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
+ "No selected options": "No hay opciones seleccionadas",
61
+ "Number of Workers": "Número de Trabajadores",
62
+ "Open Inference Server": "Abrir Servidor de Inferencia",
63
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
+ "Open Tensorboard": "Abrir Tensorboard",
65
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
67
+ "Optional online ver": "Ver en línea opcional",
68
+ "Output Path": "Ruta de Salida",
69
+ "Optional online ver": "Versión en línea opcional",
70
+ "Precision": "Precisión",
71
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
+ "Put your text here.": "Ponga su texto aquí.",
73
+ "Reference Audio": "Audio de Referencia",
74
+ "Reference Text": "Texto de Referencia",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
77
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
+ "Repetition Penalty": "Penalización por Repetición",
79
+ "Save model every n steps": "Guardar modelo cada n pasos",
80
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
82
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
+ "Selected: {}": "Seleccionado: {}",
86
+ "Speaker": "Hablante",
87
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
+ "Start Training": "Iniciar Entrenamiento",
89
+ "Streaming Audio": "transmisión de audio",
90
+ "Streaming Generate": "síntesis en flujo",
91
+ "Tensorboard Host": "Host de Tensorboard",
92
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
+ "Tensorboard Port": "Puerto de Tensorboard",
94
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
+ "Training Configuration": "Configuración de Entrenamiento",
99
+ "Training Error": "Error de Entrenamiento",
100
+ "Training stopped": "Entrenamiento detenido",
101
+ "Type name of the speaker": "Escriba el nombre del hablante",
102
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
+ "Use LoRA": "Usar LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
+ "Use filelist": "Usar lista de archivos",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
+ "VITS Configuration": "Configuración de VITS",
108
+ "VQGAN Configuration": "Configuración de VQGAN",
109
+ "Validation Batch Size": "Tamaño del Lote de Validación",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
+ "WebUI Host": "Host de WebUI",
113
+ "WebUI Port": "Puerto de WebUI",
114
+ "Whisper Model": "Modelo Whisper",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
+ "latest": "más reciente",
118
+ "new": "nuevo",
119
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
+ "Text Normalization": "Normalización de Texto",
122
+ "Select Example Audio": "Seleccionar audio de ejemplo"
123
+ }
fish_speech/i18n/locale/ja_JP.json CHANGED
@@ -1,123 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
5
- "Accumulate Gradient Batches": "勾配バッチの累積",
6
- "Add to Processing Area": "処理エリアに追加",
7
- "Added path successfully!": "パスの追加に成功しました!",
8
- "Advanced Config": "詳細設定",
9
- "Base LLAMA Model": "基本LLAMAモデル",
10
- "Batch Inference": "バッチ推論",
11
- "Batch Size": "バッチサイズ",
12
- "Changing with the Model Path": "モデルのパスに伴って変化する",
13
- "Chinese": "中国語",
14
- "Compile Model": "モデルのコンパイル",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
- "Copy": "コピー",
17
- "Data Preprocessing": "データ前処理",
18
- "Data Preprocessing Path": "データ前処理パス",
19
- "Data Source": "データソース",
20
- "Decoder Model Config": "デコーダーモデルの構成",
21
- "Decoder Model Path": "デコーダーモデルのパス",
22
- "Disabled": "無効",
23
- "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
- "English": "英語",
25
- "Error Message": "エラーメッセージ",
26
- "File Preprocessing": "文書前处理",
27
- "Generate": "生成",
28
- "Generated Audio": "生成されたオーディオ",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
- "Infer interface is closed": "推論インターフェースが閉じられています",
31
- "Inference Configuration": "推論設定",
32
- "Inference Server Configuration": "推論サーバー設定",
33
- "Inference Server Error": "推論サーバーエラー",
34
- "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
- "Initial Learning Rate": "初期学習率",
36
- "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
- "Input Text": "入力テキスト",
38
- "Invalid path: {}": "無効なパス: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
- "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
- "Japanese": "日本語",
42
- "LLAMA Configuration": "LLAMA設定",
43
- "LLAMA Model Config": "LLAMAモデル設定",
44
- "LLAMA Model Path": "LLAMAモデルパス",
45
- "Labeling Device": "ラベリングデバイス",
46
- "LoRA Model to be merged": "マージするLoRAモデル",
47
- "Maximum Audio Duration": "最大オーディオの長さ",
48
- "Maximum Length per Sample": "サンプルあたりの最大長",
49
- "Maximum Training Steps": "最大トレーニングステップ数",
50
- "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
- "Merge": "マージ",
52
- "Merge LoRA": "LoRAのマージ",
53
- "Merge successfully": "マージに成功しました",
54
- "Minimum Audio Duration": "最小オーディオの長さ",
55
- "Model Output Path": "モデル出力パス",
56
- "Model Size": "モデルサイズ",
57
- "Move": "移動",
58
- "Move files successfully": "ファイルの移動に成功しました",
59
- "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
- "No selected options": "選択されたオプションはありません",
61
- "Number of Workers": "ワーカー数",
62
- "Open Inference Server": "推論サーバーを開く",
63
- "Open Labeler WebUI": "ラベラーWebUIを開く",
64
- "Open Tensorboard": "Tensorboardを開く",
65
- "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
- "Optional Label Language": "オプションのラベル言語",
67
- "Optional online ver": "オプションのオンラインバージョン",
68
- "Output Path": "出力パス",
69
- "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
- "Precision": "精度",
71
- "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
- "Put your text here.": "ここにテキストを入力してください。",
73
- "Reference Audio": "リファレンスオーディオ",
74
- "Reference Text": "リファレンステキスト",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
- "Remove Selected Data": "選択したデータを削除",
77
- "Removed path successfully!": "パスの削除に成功しました!",
78
- "Repetition Penalty": "反復ペナルティ",
79
- "Save model every n steps": "nステップごとにモデルを保存",
80
- "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
81
- "Select VITS ckpt": "VITS チェックポイントを選択",
82
- "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
- "Select source file processing method": "ソースファイルの処理方法を選択",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
- "Selected: {}": "選択済み: {}",
86
- "Speaker": "話者",
87
- "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
- "Start Training": "トレーニング開始",
89
- "Streaming Audio": "ストリーミングオーディオ",
90
- "Streaming Generate": "ストリーミング合成",
91
- "Tensorboard Host": "Tensorboardホスト",
92
- "Tensorboard Log Path": "Tensorboardログパス",
93
- "Tensorboard Port": "Tensorboardポート",
94
- "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
- "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
- "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
- "Training Configuration": "トレーニング設定",
99
- "Training Error": "トレーニングエラー",
100
- "Training stopped": "トレーニングが停止しました",
101
- "Type name of the speaker": "話者の名前を入力",
102
- "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
- "Use LoRA": "LoRAを使用",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
- "Use filelist": "ファイルリストを使用",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
- "VITS Configuration": "VITS の構成",
108
- "VQGAN Configuration": "VQGAN の構成",
109
- "Validation Batch Size": "検証バッチサイズ",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
- "WebUI Host": "WebUIホスト",
113
- "WebUI Port": "WebUIポート",
114
- "Whisper Model": "Whisperモデル",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
- "latest": "最新",
118
- "new": "新規",
119
- "Realtime Transform Text": "リアルタイム変換テキスト",
120
- "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
- "Text Normalization": "テキスト正規化"
122
-
123
- }
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
5
+ "Accumulate Gradient Batches": "勾配バッチの累積",
6
+ "Add to Processing Area": "処理エリアに追加",
7
+ "Added path successfully!": "パスの追加に成功しました!",
8
+ "Advanced Config": "詳細設定",
9
+ "Base LLAMA Model": "基本LLAMAモデル",
10
+ "Batch Inference": "バッチ推論",
11
+ "Batch Size": "バッチサイズ",
12
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
13
+ "Chinese": "中国語",
14
+ "Compile Model": "モデルのコンパイル",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
+ "Copy": "コピー",
17
+ "Data Preprocessing": "データ前処理",
18
+ "Data Preprocessing Path": "データ前処理パス",
19
+ "Data Source": "データソース",
20
+ "Decoder Model Config": "デコーダーモデルの構成",
21
+ "Decoder Model Path": "デコーダーモデルのパス",
22
+ "Disabled": "無効",
23
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
+ "English": "英語",
25
+ "Error Message": "エラーメッセージ",
26
+ "File Preprocessing": "ファイル前処理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "生成されたオーディオ",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
+ "Infer interface is closed": "推論インターフェースが閉じられています",
31
+ "Inference Configuration": "推論設定",
32
+ "Inference Server Configuration": "推論サーバー設定",
33
+ "Inference Server Error": "推論サーバーエラー",
34
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
+ "Initial Learning Rate": "初期学習率",
36
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
+ "Input Text": "入力テキスト",
38
+ "Invalid path: {}": "無効なパス: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
+ "Japanese": "日本語",
42
+ "LLAMA Configuration": "LLAMA設定",
43
+ "LLAMA Model Config": "LLAMAモデル設定",
44
+ "LLAMA Model Path": "LLAMAモデルパス",
45
+ "Labeling Device": "ラベリングデバイス",
46
+ "LoRA Model to be merged": "マージするLoRAモデル",
47
+ "Maximum Audio Duration": "最大オーディオの長さ",
48
+ "Maximum Length per Sample": "サンプルあたりの最大長",
49
+ "Maximum Training Steps": "最大トレーニングステップ数",
50
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
+ "Merge": "マージ",
52
+ "Merge LoRA": "LoRAのマージ",
53
+ "Merge successfully": "マージに成功しました",
54
+ "Minimum Audio Duration": "最小オーディオの長さ",
55
+ "Model Output Path": "モデル出力パス",
56
+ "Model Size": "モデルサイズ",
57
+ "Move": "移動",
58
+ "Move files successfully": "ファイルの移動に成功しました",
59
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
+ "No selected options": "選択されたオプションはありません",
61
+ "Number of Workers": "ワーカー数",
62
+ "Open Inference Server": "推論サーバーを開く",
63
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
64
+ "Open Tensorboard": "Tensorboardを開く",
65
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
+ "Optional Label Language": "オプションのラベル言語",
67
+ "Optional online ver": "オプションのオンラインバージョン",
68
+ "Output Path": "出力パス",
69
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
+ "Put your text here.": "ここにテキストを入力してください。",
73
+ "Reference Audio": "リファレンスオーディオ",
74
+ "Reference Text": "リファレンステキスト",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
+ "Remove Selected Data": "選択したデータを削除",
77
+ "Removed path successfully!": "パスの削除に成功しました!",
78
+ "Repetition Penalty": "反復ペナルティ",
79
+ "Save model every n steps": "nステップごとにモデルを保存",
80
+ "Select LLAMA ckpt": "LLAMA チェックポイントを選択",
81
+ "Select VITS ckpt": "VITS チェックポイントを選択",
82
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
+ "Select source file processing method": "ソースファイルの処理方法を選択",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
+ "Selected: {}": "選択済み: {}",
86
+ "Speaker": "話者",
87
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
+ "Start Training": "トレーニング開始",
89
+ "Streaming Audio": "ストリーミングオーディオ",
90
+ "Streaming Generate": "ストリーミング合成",
91
+ "Tensorboard Host": "Tensorboardホスト",
92
+ "Tensorboard Log Path": "Tensorboardログパス",
93
+ "Tensorboard Port": "Tensorboardポート",
94
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
+ "Training Configuration": "トレーニング設定",
99
+ "Training Error": "トレーニングエラー",
100
+ "Training stopped": "トレーニングが停止しました",
101
+ "Type name of the speaker": "話者の名前を入力",
102
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
+ "Use LoRA": "LoRAを使用",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
+ "Use filelist": "ファイルリストを使用",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
+ "VITS Configuration": "VITS の構成",
108
+ "VQGAN Configuration": "VQGAN の構成",
109
+ "Validation Batch Size": "検証バッチサイズ",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
+ "WebUI Host": "WebUIホスト",
113
+ "WebUI Port": "WebUIポート",
114
+ "Whisper Model": "Whisperモデル",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
+ "latest": "最新",
118
+ "new": "新規",
119
+ "Realtime Transform Text": "リアルタイム変換テキスト",
120
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
+ "Text Normalization": "テキスト正規化",
122
+ "Select Example Audio": "サンプル音声を選択"
123
+ }
fish_speech/i18n/locale/ko_KR.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
5
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
6
+ "Add to Processing Area": "처리 영역에 추가",
7
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
8
+ "Advanced Config": "고급 설정",
9
+ "Base LLAMA Model": "기본 LLAMA 모델",
10
+ "Batch Inference": "배치 추론",
11
+ "Batch Size": "배치 크기",
12
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
13
+ "Chinese": "중국어",
14
+ "Compile Model": "모델 컴파일",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
16
+ "Copy": "복사",
17
+ "Data Preprocessing": "데이터 전처리",
18
+ "Data Preprocessing Path": "데이터 전처리 경로",
19
+ "Data Source": "데이터 소스",
20
+ "Decoder Model Config": "디코더 모델 설정",
21
+ "Decoder Model Path": "디코더 모델 경로",
22
+ "Disabled": "비활성화 됨",
23
+ "Enable Reference Audio": "참고 음성 활성화",
24
+ "English": "영어",
25
+ "Error Message": "오류 메시지",
26
+ "File Preprocessing": "파일 전처리",
27
+ "Generate": "생성",
28
+ "Generated Audio": "생성된 오디오",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오에 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
30
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
31
+ "Inference Configuration": "추론 설정",
32
+ "Inference Server Configuration": "추론 서버 설정",
33
+ "Inference Server Error": "추론 서버 오류",
34
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
35
+ "Initial Learning Rate": "초기 학습률",
36
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
37
+ "Input Text": "입력 텍스트",
38
+ "Invalid path: {}": "유효하지 않은 경로: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
40
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
41
+ "Japanese": "일본어",
42
+ "LLAMA Configuration": "LLAMA 설정",
43
+ "LLAMA Model Config": "LLAMA 모델 설정",
44
+ "LLAMA Model Path": "LLAMA 모델 경로",
45
+ "Labeling Device": "라벨링 장치",
46
+ "LoRA Model to be merged": "병합할 LoRA 모델",
47
+ "Maximum Audio Duration": "최대 오디오 길이",
48
+ "Maximum Length per Sample": "샘플당 최대 길이",
49
+ "Maximum Training Steps": "최대 학습 단계",
50
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
51
+ "Merge": "병합",
52
+ "Merge LoRA": "LoRA 병합",
53
+ "Merge successfully": "성공적으로 병합 되었습니다.",
54
+ "Minimum Audio Duration": "최소 오디오 길이",
55
+ "Model Output Path": "모델 출력 경로",
56
+ "Model Size": "모델 크기",
57
+ "Move": "이동",
58
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
59
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
60
+ "No selected options": "옵션이 선택되지 않았습니다.",
61
+ "Number of Workers": "작업자 수",
62
+ "Open Inference Server": "추론 서버 열기",
63
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
64
+ "Open Tensorboard": "Tensorboard 열기",
65
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
66
+ "Optional Label Language": "선택적 라벨 언어",
67
+ "Optional online ver": "온라인 버전 선택",
68
+ "Output Path": "출력 경로",
69
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
70
+ "Precision": "정밀도",
71
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
72
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
73
+ "Reference Audio": "참고 오디오",
74
+ "Reference Text": "참고 텍스트",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
76
+ "Remove Selected Data": "선택한 데이터 제거",
77
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
78
+ "Repetition Penalty": "반복 패널티",
79
+ "Save model every n steps": "n 단계마다 모델 저장",
80
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
81
+ "Select VITS ckpt": "VITS ckpt 선택",
82
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
83
+ "Select source file processing method": "소스 파일 처리 방법 선택",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
85
+ "Selected: {}": "선택됨: {}",
86
+ "Speaker": "화자",
87
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
88
+ "Start Training": "학습 시작",
89
+ "Streaming Audio": "스트리밍 오디오",
90
+ "Streaming Generate": "스트리밍 생성",
91
+ "Tensorboard Host": "Tensorboard 호스트",
92
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
93
+ "Tensorboard Port": "Tensorboard 포트",
94
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
96
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
98
+ "Training Configuration": "학습 설정",
99
+ "Training Error": "학습 오류",
100
+ "Training stopped": "학습이 중지되었습니다.",
101
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
102
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
103
+ "Use LoRA": "LoRA 사용",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
105
+ "Use filelist": "파일 목록 사용",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
107
+ "VITS Configuration": "VITS 설정",
108
+ "VQGAN Configuration": "VQGAN 설정",
109
+ "Validation Batch Size": "검증 배치 크기",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
112
+ "WebUI Host": "WebUI 호스트",
113
+ "WebUI Port": "WebUI 포트",
114
+ "Whisper Model": "Whisper 모델",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
117
+ "latest": "최신",
118
+ "new": "새로운",
119
+ "Realtime Transform Text": "실시간 텍스트 변환",
120
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
121
+ "Text Normalization": "텍스트 정규화",
122
+ "Select Example Audio": "예시 오디오 선택"
123
+ }
fish_speech/i18n/locale/pt_BR.json CHANGED
@@ -1,133 +1,133 @@
1
- {
2
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
- "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
- "Add to Processing Area": "Adicionar à Área de Processamento",
6
- "Added path successfully!": "Caminho adicionado com sucesso!",
7
- "Advanced Config": "Configuração Avançada",
8
- "Base LLAMA Model": "Modelo LLAMA Base",
9
- "Batch Inference": "Inferência em Lote",
10
- "Batch Size": "Tamanho do Lote",
11
- "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
-
13
- "Compile Model": "Compilar Modelo",
14
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
- "Copy": "Copiar",
16
- "Data Preprocessing": "Pré-processamento de Dados",
17
- "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
- "Data Source": "Fonte de Dados",
19
- "Decoder Model Config": "Configuração do Modelo Decodificador",
20
- "Decoder Model Path": "Caminho do Modelo Decodificador",
21
- "Disabled": "Desativado",
22
- "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
- "Enable Reference Audio": "Habilitar Áudio de Referência",
24
- "English": "Inglês",
25
- "Japanese": "Japonês",
26
- "Chinese": "Chinês",
27
- "Portuguese": "Português",
28
- "Spanish": "Espanhol",
29
- "Error Message": "Mensagem de Erro",
30
- "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
- "File Preprocessing": "Pré-processamento de Arquivos",
32
- "Generate": "Gerar",
33
- "Generated Audio": "Áudio Gerado",
34
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
- "Infer interface is closed": "A interface de inferência foi fechada",
36
- "Inference Configuration": "Configuração de Inferência",
37
- "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
- "Inference Server Error": "Erro do Servidor de Inferência",
39
- "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
- "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
- "Initial Prompt": "Prompt Inicial",
42
- "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
- "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
- "Input Text": "Texto de Entrada",
45
- "Invalid path: {}": "Caminho inválido: {}",
46
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
- "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
- "LLAMA Configuration": "Configuração do LLAMA",
49
- "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
- "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
- "Labeling Device": "Dispositivo de Rotulagem",
52
- "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
- "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
- "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
- "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
- "Merge": "Mesclar",
57
- "Merge LoRA": "Mesclar LoRA",
58
- "Merge successfully": "Mesclado com sucesso",
59
- "Model Output Path": "Caminho de Saída do Modelo",
60
- "Model Quantization": "Quantização do Modelo",
61
- "Model Size": "Tamanho do Modelo",
62
- "Move": "Mover",
63
- "Move files successfully": "Arquivos movidos com sucesso",
64
- "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
- "No selected options": "Nenhuma opção selecionada",
66
- "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
- "Number of Workers": "Número de Processos",
68
- "Open Inference Server": "Abrir Servidor de Inferência",
69
- "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
- "Open Tensorboard": "Abrir Tensorboard",
71
- "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
- "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
- "Optional online ver": "Versão online (opcional)",
74
- "Output Path": "Caminho de Saída",
75
- "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
- "Post-quantification Precision": "Precisão Pós-quantização",
77
- "Precision": "Precisão",
78
- "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
- "Put your text here.": "Insira seu texto aqui.",
80
- "Quantify": "Quantizar",
81
- "Quantify successfully": "Quantizado com sucesso",
82
- "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
- "Reference Audio": "Áudio de Referência",
84
- "Reference Text": "Texto de Referência",
85
- "warning": "Aviso",
86
- "Pre-processing begins...": "O pré-processamento começou!",
87
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
- "Remove Selected Data": "Remover Dados Selecionados",
89
- "Removed path successfully!": "Caminho removido com sucesso!",
90
- "Repetition Penalty": "Penalidade de Repetição",
91
- "Save model every n steps": "Salvar modelo a cada n etapas",
92
- "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
- "Select source file processing method": "Escolha como processar o arquivo de origem",
94
- "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
- "Selected: {}": "Selecionado: {}",
96
- "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
- "Start Training": "Iniciar Treinamento",
98
- "Streaming Audio": "Áudio em Streaming",
99
- "Streaming Generate": "Geração em Streaming",
100
- "Tensorboard Host": "Host do Tensorboard",
101
- "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
- "Tensorboard Port": "Porta do Tensorboard",
103
- "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
- "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
- "Text Normalization": "Normalização de Texto",
106
- "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
- "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
- "Training Configuration": "Configuração de Treinamento",
110
- "Training Error": "Erro de Treinamento",
111
- "Training stopped": "Treinamento interrompido!",
112
- "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
- "Use LoRA": "Usar LoRA",
114
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
- "Use filelist": "Usar lista de arquivos",
116
- "VQGAN Configuration": "Configuração do VQGAN",
117
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
- "WebUI Host": "Host da WebUI",
120
- "WebUI Port": "Porta da WebUI",
121
- "Whisper Model": "Modelo Whisper",
122
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
- "auto": "automático",
124
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
- "latest": "mais recente",
126
- "new": "novo",
127
- "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
- "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
- "Yes": "Sim",
130
- "No": "Não",
131
- "version:": "versão:",
132
- "author:": "autor:"
133
- }
 
1
+ {
2
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
+ "Add to Processing Area": "Adicionar à Área de Processamento",
6
+ "Added path successfully!": "Caminho adicionado com sucesso!",
7
+ "Advanced Config": "Configuração Avançada",
8
+ "Base LLAMA Model": "Modelo LLAMA Base",
9
+ "Batch Inference": "Inferência em Lote",
10
+ "Batch Size": "Tamanho do Lote",
11
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
+
13
+ "Compile Model": "Compilar Modelo",
14
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
+ "Copy": "Copiar",
16
+ "Data Preprocessing": "Pré-processamento de Dados",
17
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
+ "Data Source": "Fonte de Dados",
19
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
20
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
21
+ "Disabled": "Desativado",
22
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
24
+ "English": "Inglês",
25
+ "Japanese": "Japonês",
26
+ "Chinese": "Chinês",
27
+ "Portuguese": "Português",
28
+ "Spanish": "Espanhol",
29
+ "Error Message": "Mensagem de Erro",
30
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
+ "File Preprocessing": "Pré-processamento de Arquivos",
32
+ "Generate": "Gerar",
33
+ "Generated Audio": "Áudio Gerado",
34
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
+ "Infer interface is closed": "A interface de inferência foi fechada",
36
+ "Inference Configuration": "Configuração de Inferência",
37
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
+ "Inference Server Error": "Erro do Servidor de Inferência",
39
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
+ "Initial Prompt": "Prompt Inicial",
42
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
+ "Input Text": "Texto de Entrada",
45
+ "Invalid path: {}": "Caminho inválido: {}",
46
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
+ "LLAMA Configuration": "Configuração do LLAMA",
49
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
+ "Labeling Device": "Dispositivo de Rotulagem",
52
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
+ "Merge": "Mesclar",
57
+ "Merge LoRA": "Mesclar LoRA",
58
+ "Merge successfully": "Mesclado com sucesso",
59
+ "Model Output Path": "Caminho de Saída do Modelo",
60
+ "Model Quantization": "Quantização do Modelo",
61
+ "Model Size": "Tamanho do Modelo",
62
+ "Move": "Mover",
63
+ "Move files successfully": "Arquivos movidos com sucesso",
64
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
+ "No selected options": "Nenhuma opção selecionada",
66
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
+ "Number of Workers": "Número de Processos",
68
+ "Open Inference Server": "Abrir Servidor de Inferência",
69
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
+ "Open Tensorboard": "Abrir Tensorboard",
71
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
+ "Optional online ver": "Versão online (opcional)",
74
+ "Output Path": "Caminho de Saída",
75
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
+ "Post-quantification Precision": "Precisão Pós-quantização",
77
+ "Precision": "Precisão",
78
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
+ "Put your text here.": "Insira seu texto aqui.",
80
+ "Quantify": "Quantizar",
81
+ "Quantify successfully": "Quantizado com sucesso",
82
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
+ "Reference Audio": "Áudio de Referência",
84
+ "Reference Text": "Texto de Referência",
85
+ "warning": "Aviso",
86
+ "Pre-processing begins...": "O pré-processamento começou!",
87
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
+ "Remove Selected Data": "Remover Dados Selecionados",
89
+ "Removed path successfully!": "Caminho removido com sucesso!",
90
+ "Repetition Penalty": "Penalidade de Repetição",
91
+ "Save model every n steps": "Salvar modelo a cada n etapas",
92
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
94
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
+ "Selected: {}": "Selecionado: {}",
96
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
+ "Start Training": "Iniciar Treinamento",
98
+ "Streaming Audio": "Áudio em Streaming",
99
+ "Streaming Generate": "Geração em Streaming",
100
+ "Tensorboard Host": "Host do Tensorboard",
101
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
+ "Tensorboard Port": "Porta do Tensorboard",
103
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
+ "Text Normalization": "Normalização de Texto",
106
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
+ "Training Configuration": "Configuração de Treinamento",
110
+ "Training Error": "Erro de Treinamento",
111
+ "Training stopped": "Treinamento interrompido!",
112
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
+ "Use LoRA": "Usar LoRA",
114
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
+ "Use filelist": "Usar lista de arquivos",
116
+ "VQGAN Configuration": "Configuração do VQGAN",
117
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
+ "WebUI Host": "Host da WebUI",
120
+ "WebUI Port": "Porta da WebUI",
121
+ "Whisper Model": "Modelo Whisper",
122
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
+ "auto": "automático",
124
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
+ "latest": "mais recente",
126
+ "new": "novo",
127
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
+ "Yes": "Sim",
130
+ "No": "Não",
131
+ "version:": "versão:",
132
+ "author:": "autor:"
133
+ }
fish_speech/i18n/locale/zh_CN.json CHANGED
@@ -1,122 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
- "Accumulate Gradient Batches": "梯度累积批次",
6
- "Add to Processing Area": "加入处理区",
7
- "Added path successfully!": "添加路径成功!",
8
- "Advanced Config": "高级参数",
9
- "Base LLAMA Model": "基础 LLAMA 模型",
10
- "Batch Inference": "批量推理",
11
- "Batch Size": "批次大小",
12
- "Changing with the Model Path": "随模型路径变化",
13
- "Chinese": "中文",
14
- "Compile Model": "编译模型",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
- "Copy": "复制",
17
- "Data Preprocessing": "数据预处理",
18
- "Data Preprocessing Path": "数据预处理路径",
19
- "Data Source": "数据源",
20
- "Decoder Model Config": "解码器模型配置",
21
- "Decoder Model Path": "解码器模型路径",
22
- "Disabled": "禁用",
23
- "Enable Reference Audio": "启用参考音频",
24
- "English": "英文",
25
- "Error Message": "错误信息",
26
- "File Preprocessing": "文件预处理",
27
- "Generate": "生成",
28
- "Generated Audio": "音频",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
- "Infer interface is closed": "推理界面已关闭",
31
- "Inference Configuration": "推理配置",
32
- "Inference Server Configuration": "推理服务器配置",
33
- "Inference Server Error": "推理服务器错误",
34
- "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
- "Initial Learning Rate": "初始学习率",
36
- "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
- "Input Text": "输入文本",
38
- "Invalid path: {}": "无效路径: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
- "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
- "Japanese": "日文",
42
- "LLAMA Configuration": "LLAMA 配置",
43
- "LLAMA Model Config": "LLAMA 模型配置",
44
- "LLAMA Model Path": "LLAMA 模型路径",
45
- "Labeling Device": "标注加速设备",
46
- "LoRA Model to be merged": "要合并的 LoRA 模型",
47
- "Maximum Audio Duration": "最大音频时长",
48
- "Maximum Length per Sample": "每个样本的最大长度",
49
- "Maximum Training Steps": "最大训练步数",
50
- "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
- "Merge": "合并",
52
- "Merge LoRA": "合并 LoRA",
53
- "Merge successfully": "合并成功",
54
- "Minimum Audio Duration": "最小音频时长",
55
- "Model Output Path": "模型输出路径",
56
- "Model Size": "模型规模",
57
- "Move": "移动",
58
- "Move files successfully": "移动文件成功",
59
- "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
- "No selected options": "没有选择的选项",
61
- "Number of Workers": "数据加载进程数",
62
- "Open Inference Server": "打开推理服务器",
63
- "Open Labeler WebUI": "打开标注工具",
64
- "Open Tensorboard": "打开 Tensorboard",
65
- "Opened labeler in browser": "在浏览器中打开标注工具",
66
- "Optional Label Language": "[可选] 标注语言",
67
- "Optional online ver": "[可选] 使用在线版",
68
- "Output Path": "输出路径",
69
- "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
- "Precision": "精度",
71
- "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
- "Put your text here.": "在此处输入文本.",
73
- "Reference Audio": "参考音频",
74
- "Reference Text": "参考文本",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
- "Remove Selected Data": "移除选中数据",
77
- "Removed path successfully!": "移除路径成功!",
78
- "Repetition Penalty": "重复惩罚",
79
- "Save model every n steps": "每 n 步保存模型",
80
- "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
- "Select VITS ckpt": "选择 VITS 检查点",
82
- "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
- "Select source file processing method": "选择源文件处理方法",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
- "Selected: {}": "已选择: {}",
86
- "Speaker": "说话人",
87
- "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
- "Start Training": "开始训练",
89
- "Streaming Audio": "流式音频",
90
- "Streaming Generate": "流式合成",
91
- "Tensorboard Host": "Tensorboard 监听地址",
92
- "Tensorboard Log Path": "Tensorboard 日志路径",
93
- "Tensorboard Port": "Tensorboard 端口",
94
- "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
- "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
- "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
- "Training Configuration": "训练配置",
99
- "Training Error": "训练错误",
100
- "Training stopped": "训练已停止",
101
- "Type name of the speaker": "输入说话人的名称",
102
- "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
- "Use LoRA": "使用 LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
- "Use filelist": "使用文件列表",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
- "VITS Configuration": "VITS 配置",
108
- "VQGAN Configuration": "VQGAN 配置",
109
- "Validation Batch Size": "验证批次大小",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
- "WebUI Host": "WebUI 监听地址",
113
- "WebUI Port": "WebUI 端口",
114
- "Whisper Model": "Whisper 模型",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
- "latest": "最近的检查点",
118
- "new": "创建新的检查点",
119
- "Realtime Transform Text": "实时规范化文本",
120
- "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
- "Text Normalization": "文本规范化"
122
- }
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
+ "Accumulate Gradient Batches": "梯度累积批次",
6
+ "Add to Processing Area": "加入处理区",
7
+ "Added path successfully!": "添加路径成功!",
8
+ "Advanced Config": "高级参数",
9
+ "Base LLAMA Model": "基础 LLAMA 模型",
10
+ "Batch Inference": "批量推理",
11
+ "Batch Size": "批次大小",
12
+ "Changing with the Model Path": "随模型路径变化",
13
+ "Chinese": "中文",
14
+ "Compile Model": "编译模型",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
+ "Copy": "复制",
17
+ "Data Preprocessing": "数据预处理",
18
+ "Data Preprocessing Path": "数据预处理路径",
19
+ "Data Source": "数据源",
20
+ "Decoder Model Config": "解码器模型配置",
21
+ "Decoder Model Path": "解码器模型路径",
22
+ "Disabled": "禁用",
23
+ "Enable Reference Audio": "启用参考音频",
24
+ "English": "英文",
25
+ "Error Message": "错误信息",
26
+ "File Preprocessing": "文件预处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "音频",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
+ "Infer interface is closed": "推理界面已关闭",
31
+ "Inference Configuration": "推理配置",
32
+ "Inference Server Configuration": "推理服务器配置",
33
+ "Inference Server Error": "推理服务器错误",
34
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
+ "Initial Learning Rate": "初始学习率",
36
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
+ "Input Text": "输入文本",
38
+ "Invalid path: {}": "无效路径: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
+ "Japanese": "日文",
42
+ "LLAMA Configuration": "LLAMA 配置",
43
+ "LLAMA Model Config": "LLAMA 模型配置",
44
+ "LLAMA Model Path": "LLAMA 模型路径",
45
+ "Labeling Device": "标注加速设备",
46
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
47
+ "Maximum Audio Duration": "最大音频时长",
48
+ "Maximum Length per Sample": "每个样本的最大长度",
49
+ "Maximum Training Steps": "最大训练步数",
50
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
+ "Merge": "合并",
52
+ "Merge LoRA": "合并 LoRA",
53
+ "Merge successfully": "合并成功",
54
+ "Minimum Audio Duration": "最小音频时长",
55
+ "Model Output Path": "模型输出路径",
56
+ "Model Size": "模型规模",
57
+ "Move": "移动",
58
+ "Move files successfully": "移动文件成功",
59
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
+ "No selected options": "没有选择的选项",
61
+ "Number of Workers": "数据加载进程数",
62
+ "Open Inference Server": "打开推理服务器",
63
+ "Open Labeler WebUI": "打开标注工具",
64
+ "Open Tensorboard": "打开 Tensorboard",
65
+ "Opened labeler in browser": "在浏览器中打开标注工具",
66
+ "Optional Label Language": "[可选] 标注语言",
67
+ "Optional online ver": "[可选] 使用在线版",
68
+ "Output Path": "输出路径",
69
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
+ "Put your text here.": "在此处输入文本.",
73
+ "Reference Audio": "参考音频",
74
+ "Reference Text": "参考文本",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
+ "Remove Selected Data": "移除选中数据",
77
+ "Removed path successfully!": "移除路径成功!",
78
+ "Repetition Penalty": "重复惩罚",
79
+ "Save model every n steps": "每 n 步保存模型",
80
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
+ "Select VITS ckpt": "选择 VITS 检查点",
82
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
+ "Select source file processing method": "选择源文件处理方法",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
+ "Selected: {}": "已选择: {}",
86
+ "Speaker": "说话人",
87
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
+ "Start Training": "开始训练",
89
+ "Streaming Audio": "流式音频",
90
+ "Streaming Generate": "流式合成",
91
+ "Tensorboard Host": "Tensorboard 监听地址",
92
+ "Tensorboard Log Path": "Tensorboard 日志路径",
93
+ "Tensorboard Port": "Tensorboard 端口",
94
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
+ "Training Configuration": "训练配置",
99
+ "Training Error": "训练错误",
100
+ "Training stopped": "训练已停止",
101
+ "Type name of the speaker": "输入说话人的名称",
102
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
+ "Use LoRA": "使用 LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
+ "Use filelist": "使用文件列表",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
+ "VITS Configuration": "VITS 配置",
108
+ "VQGAN Configuration": "VQGAN 配置",
109
+ "Validation Batch Size": "验证批次大小",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
+ "WebUI Host": "WebUI 监听地址",
113
+ "WebUI Port": "WebUI 端口",
114
+ "Whisper Model": "Whisper 模型",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
+ "latest": "最近的检查点",
118
+ "new": "创建新的检查点",
119
+ "Realtime Transform Text": "实时规范化文本",
120
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
+ "Text Normalization": "文本规范化",
122
+ "Select Example Audio": "选择参考音频"
123
+ }
fish_speech/i18n/scan.py CHANGED
@@ -1,122 +1,122 @@
1
- import ast
2
- import glob
3
- import json
4
- from collections import OrderedDict
5
- from pathlib import Path
6
-
7
- from loguru import logger
8
-
9
- from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
-
11
-
12
- def extract_i18n_strings(node):
13
- i18n_strings = []
14
-
15
- if (
16
- isinstance(node, ast.Call)
17
- and isinstance(node.func, ast.Name)
18
- and node.func.id == "i18n"
19
- ):
20
- for arg in node.args:
21
- if isinstance(arg, ast.Str):
22
- i18n_strings.append(arg.s)
23
-
24
- for child_node in ast.iter_child_nodes(node):
25
- i18n_strings.extend(extract_i18n_strings(child_node))
26
-
27
- return i18n_strings
28
-
29
-
30
- # scan the directory for all .py files (recursively)
31
- # for each file, parse the code into an AST
32
- # for each AST, extract the i18n strings
33
-
34
- strings = []
35
- folders = ["fish_speech", "tools"]
36
- # for filename in glob.iglob("**/*.py", recursive=True):
37
- for folder in folders:
38
- for f in Path(folder).rglob("*.py"):
39
- code = f.read_text(encoding="utf-8")
40
- if "i18n(" in code:
41
- tree = ast.parse(code)
42
- i18n_strings = extract_i18n_strings(tree)
43
- logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
- strings.extend(i18n_strings)
45
-
46
- code_keys = set(strings)
47
- logger.info(f"Total unique: {len(code_keys)}")
48
-
49
-
50
- standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
- with open(standard_file, "r", encoding="utf-8") as f:
52
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
- standard_keys = set(standard_data.keys())
54
-
55
- # Define the standard file name
56
- unused_keys = standard_keys - code_keys
57
- logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
- for unused_key in unused_keys:
59
- logger.info(f"\t{unused_key}")
60
-
61
- missing_keys = code_keys - standard_keys
62
- logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
- for missing_key in missing_keys:
64
- logger.info(f"\t{missing_key}")
65
-
66
- code_keys_dict = OrderedDict()
67
- for s in strings:
68
- code_keys_dict[s] = s
69
-
70
- # write back
71
- with open(standard_file, "w", encoding="utf-8") as f:
72
- json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
- f.write("\n")
74
-
75
- logger.info(f"Updated {standard_file}")
76
-
77
-
78
- # Define the standard file name
79
- standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
-
81
- # Find all JSON files in the directory
82
- dir_path = I18N_FILE_PATH
83
- languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
-
85
- # Load the standard file
86
- with open(standard_file, "r", encoding="utf-8") as f:
87
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
-
89
- # Loop through each language file
90
- for lang_file in languages:
91
- # Load the language file
92
- with open(lang_file, "r", encoding="utf-8") as f:
93
- lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
-
95
- # Find the difference between the language file and the standard file
96
- diff = set(standard_data.keys()) - set(lang_data.keys())
97
-
98
- miss = set(lang_data.keys()) - set(standard_data.keys())
99
-
100
- # Add any missing keys to the language file
101
- for key in diff:
102
- lang_data[key] = "#!" + key
103
- logger.info(f"Added missing key: {key} to {lang_file}")
104
-
105
- # Del any extra keys to the language file
106
- for key in miss:
107
- del lang_data[key]
108
- logger.info(f"Del extra key: {key} from {lang_file}")
109
-
110
- # Sort the keys of the language file to match the order of the standard file
111
- lang_data = OrderedDict(
112
- sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
- )
114
-
115
- # Save the updated language file
116
- with open(lang_file, "w", encoding="utf-8") as f:
117
- json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
- f.write("\n")
119
-
120
- logger.info(f"Updated {lang_file}")
121
-
122
- logger.info("Done")
 
1
+ import ast
2
+ import glob
3
+ import json
4
+ from collections import OrderedDict
5
+ from pathlib import Path
6
+
7
+ from loguru import logger
8
+
9
+ from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
+
11
+
12
+ def extract_i18n_strings(node):
13
+ i18n_strings = []
14
+
15
+ if (
16
+ isinstance(node, ast.Call)
17
+ and isinstance(node.func, ast.Name)
18
+ and node.func.id == "i18n"
19
+ ):
20
+ for arg in node.args:
21
+ if isinstance(arg, ast.Str):
22
+ i18n_strings.append(arg.s)
23
+
24
+ for child_node in ast.iter_child_nodes(node):
25
+ i18n_strings.extend(extract_i18n_strings(child_node))
26
+
27
+ return i18n_strings
28
+
29
+
30
+ # scan the directory for all .py files (recursively)
31
+ # for each file, parse the code into an AST
32
+ # for each AST, extract the i18n strings
33
+
34
+ strings = []
35
+ folders = ["fish_speech", "tools"]
36
+ # for filename in glob.iglob("**/*.py", recursive=True):
37
+ for folder in folders:
38
+ for f in Path(folder).rglob("*.py"):
39
+ code = f.read_text(encoding="utf-8")
40
+ if "i18n(" in code:
41
+ tree = ast.parse(code)
42
+ i18n_strings = extract_i18n_strings(tree)
43
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
+ strings.extend(i18n_strings)
45
+
46
+ code_keys = set(strings)
47
+ logger.info(f"Total unique: {len(code_keys)}")
48
+
49
+
50
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
+ with open(standard_file, "r", encoding="utf-8") as f:
52
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
+ standard_keys = set(standard_data.keys())
54
+
55
+ # Define the standard file name
56
+ unused_keys = standard_keys - code_keys
57
+ logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
+ for unused_key in unused_keys:
59
+ logger.info(f"\t{unused_key}")
60
+
61
+ missing_keys = code_keys - standard_keys
62
+ logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
+ for missing_key in missing_keys:
64
+ logger.info(f"\t{missing_key}")
65
+
66
+ code_keys_dict = OrderedDict()
67
+ for s in strings:
68
+ code_keys_dict[s] = s
69
+
70
+ # write back
71
+ with open(standard_file, "w", encoding="utf-8") as f:
72
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
+ f.write("\n")
74
+
75
+ logger.info(f"Updated {standard_file}")
76
+
77
+
78
+ # Define the standard file name
79
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
+
81
+ # Find all JSON files in the directory
82
+ dir_path = I18N_FILE_PATH
83
+ languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
+
85
+ # Load the standard file
86
+ with open(standard_file, "r", encoding="utf-8") as f:
87
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
+
89
+ # Loop through each language file
90
+ for lang_file in languages:
91
+ # Load the language file
92
+ with open(lang_file, "r", encoding="utf-8") as f:
93
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
+
95
+ # Find the difference between the language file and the standard file
96
+ diff = set(standard_data.keys()) - set(lang_data.keys())
97
+
98
+ miss = set(lang_data.keys()) - set(standard_data.keys())
99
+
100
+ # Add any missing keys to the language file
101
+ for key in diff:
102
+ lang_data[key] = "#!" + key
103
+ logger.info(f"Added missing key: {key} to {lang_file}")
104
+
105
+ # Delete any extra keys from the language file
106
+ for key in miss:
107
+ del lang_data[key]
108
+ logger.info(f"Del extra key: {key} from {lang_file}")
109
+
110
+ # Sort the keys of the language file to match the order of the standard file
111
+ lang_data = OrderedDict(
112
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
+ )
114
+
115
+ # Save the updated language file
116
+ with open(lang_file, "w", encoding="utf-8") as f:
117
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
+ f.write("\n")
119
+
120
+ logger.info(f"Updated {lang_file}")
121
+
122
+ logger.info("Done")
fish_speech/models/text2semantic/lit_module.py CHANGED
@@ -1,202 +1,202 @@
1
- from typing import Any, Optional
2
-
3
- import lightning as L
4
- import torch
5
- import torch.nn.functional as F
6
- from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
-
8
- import fish_speech.utils as utils
9
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
- from fish_speech.models.text2semantic.llama import NaiveTransformer
11
-
12
- log = utils.RankedLogger(__name__, rank_zero_only=True)
13
-
14
-
15
class TextToSemantic(L.LightningModule):
    """Lightning module that trains a text-to-semantic transformer.

    The wrapped model is expected to return an object exposing
    ``token_logits`` and ``codebook_logits``; the training loss is the sum of
    a cross-entropy over each of them.
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        """Store the model and the optimizer / scheduler builders.

        Args:
            model: The transformer to train.
            optimizer: Callable invoked with a list of parameter groups; must
                return the optimizer.
            lr_scheduler: Callable invoked with the optimizer; must return the
                learning-rate scheduler.
        """
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        """Keep only LoRA parameters in the checkpoint when LoRA is in use.

        If no key of the state dict contains ``"lora"`` the checkpoint is left
        untouched; otherwise every non-LoRA entry is removed in place.
        """
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return

        # Iterate over a snapshot of the keys because entries are removed.
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer with two parameter groups and a per-step scheduler.

        Biases, ``norm.weight`` parameters and embeddings get ``weight_decay``
        0.0; every other parameter uses the optimizer's own setting.
        """
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Report the effective weight decay of every parameter group.
        for group in optimizer.param_groups:
            log.info(
                f"Set weight decay: {group['weight_decay']} for {len(group['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Unnormalized model logits. ``logits.shape[:-1]`` must equal
                ``labels.shape``; the last dimension is the vocabulary.
            labels: Target ids. Entries equal to -100 are ignored.
            average_log_prob: If True, return the mean log probability over the
                non-masked entries; otherwise return their sum.

        Returns:
            The sum (or mean) of the label log probabilities, reduced over the
            LAST dimension of ``labels`` only, so the result has shape
            ``labels.shape[:-1]``. When every entry along that dimension is
            masked the result is 0 in both modes.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        summed_logps = (per_token_logps * loss_mask).sum(-1)
        if average_log_prob:
            # Clamp the count so a fully masked row yields 0 instead of 0 / 0 = NaN.
            return summed_logps / loss_mask.sum(-1).clamp(min=1)
        return summed_logps

    def _step(self, batch, batch_idx, stage: str):
        """Run one forward pass, log the metrics under ``stage`` and return the loss.

        Args:
            batch: Mapping with ``"inputs"``, ``"labels"`` and
                ``"attention_masks"`` entries.
            batch_idx: Index of the batch (unused).
            stage: Metric prefix; ``"train"`` additionally switches the model
                to training mode and logs per step instead of per epoch.
        """
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Loss on the text-token stream (row 0 of the labels).
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # Loss on the codebook streams (rows 1..num_codebooks), with the last
        # two axes swapped to line up with codebook_logits.
        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)

        # (metric name, value, show on progress bar)
        metrics = (
            ("loss", loss, True),
            ("base_loss", base_loss, False),
            ("semantic_loss", semantic_loss, False),
            ("top_5_accuracy", accuracy, True),
        )
        for name, value, prog_bar in metrics:
            self.log(
                f"{stage}/{name}",
                value,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=prog_bar,
                logger=True,
                sync_dist=not is_train,
            )

        return loss

    def get_accuracy(self, logits, labels):
        """Return the fraction of valid labels found among the top-5 logits.

        Labels equal to -100 or ``CODEBOOK_PAD_TOKEN_ID`` are excluded. Returns
        0.0 when no valid label exists. If the last dimension of ``logits`` is
        smaller than 5, all of its entries are considered.
        """
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        valid_count = mask.sum()
        if valid_count == 0:
            return torch.tensor(0.0, device=logits.device)

        # topk raises when k exceeds the dimension size, so cap it.
        top_k = min(5, logits.size(-1))
        _, indices = logits.topk(top_k, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1)) & mask.unsqueeze(-1)

        return correct.sum() / valid_count

    def training_step(self, batch, batch_idx):
        """Lightning hook: one training step."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: one validation step."""
        return self._step(batch, batch_idx, "val")
 
1
+ from typing import Any, Optional
2
+
3
+ import lightning as L
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
+
8
+ import fish_speech.utils as utils
9
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
+ from fish_speech.models.text2semantic.llama import NaiveTransformer
11
+
12
+ log = utils.RankedLogger(__name__, rank_zero_only=True)
13
+
14
+
15
class TextToSemantic(L.LightningModule):
    """Lightning module that trains a text-to-semantic transformer.

    The wrapped model is expected to return an object exposing
    ``token_logits`` and ``codebook_logits``; the training loss is the sum of
    a cross-entropy over each of them.
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        """Store the model and the optimizer / scheduler builders.

        Args:
            model: The transformer to train.
            optimizer: Callable invoked with a list of parameter groups; must
                return the optimizer.
            lr_scheduler: Callable invoked with the optimizer; must return the
                learning-rate scheduler.
        """
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        """Keep only LoRA parameters in the checkpoint when LoRA is in use.

        If no key of the state dict contains ``"lora"`` the checkpoint is left
        untouched; otherwise every non-LoRA entry is removed in place.
        """
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return

        # Iterate over a snapshot of the keys because entries are removed.
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer with two parameter groups and a per-step scheduler.

        Biases, ``norm.weight`` parameters and embeddings get ``weight_decay``
        0.0; every other parameter uses the optimizer's own setting.
        """
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Report the effective weight decay of every parameter group.
        for group in optimizer.param_groups:
            log.info(
                f"Set weight decay: {group['weight_decay']} for {len(group['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Unnormalized model logits. ``logits.shape[:-1]`` must equal
                ``labels.shape``; the last dimension is the vocabulary.
            labels: Target ids. Entries equal to -100 are ignored.
            average_log_prob: If True, return the mean log probability over the
                non-masked entries; otherwise return their sum.

        Returns:
            The sum (or mean) of the label log probabilities, reduced over the
            LAST dimension of ``labels`` only, so the result has shape
            ``labels.shape[:-1]``. When every entry along that dimension is
            masked the result is 0 in both modes.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        summed_logps = (per_token_logps * loss_mask).sum(-1)
        if average_log_prob:
            # Clamp the count so a fully masked row yields 0 instead of 0 / 0 = NaN.
            return summed_logps / loss_mask.sum(-1).clamp(min=1)
        return summed_logps

    def _step(self, batch, batch_idx, stage: str):
        """Run one forward pass, log the metrics under ``stage`` and return the loss.

        Args:
            batch: Mapping with ``"inputs"``, ``"labels"`` and
                ``"attention_masks"`` entries.
            batch_idx: Index of the batch (unused).
            stage: Metric prefix; ``"train"`` additionally switches the model
                to training mode and logs per step instead of per epoch.
        """
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Loss on the text-token stream (row 0 of the labels).
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # Loss on the codebook streams (rows 1..num_codebooks), with the last
        # two axes swapped to line up with codebook_logits.
        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)

        # (metric name, value, show on progress bar)
        metrics = (
            ("loss", loss, True),
            ("base_loss", base_loss, False),
            ("semantic_loss", semantic_loss, False),
            ("top_5_accuracy", accuracy, True),
        )
        for name, value, prog_bar in metrics:
            self.log(
                f"{stage}/{name}",
                value,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=prog_bar,
                logger=True,
                sync_dist=not is_train,
            )

        return loss

    def get_accuracy(self, logits, labels):
        """Return the fraction of valid labels found among the top-5 logits.

        Labels equal to -100 or ``CODEBOOK_PAD_TOKEN_ID`` are excluded. Returns
        0.0 when no valid label exists. If the last dimension of ``logits`` is
        smaller than 5, all of its entries are considered.
        """
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        valid_count = mask.sum()
        if valid_count == 0:
            return torch.tensor(0.0, device=logits.device)

        # topk raises when k exceeds the dimension size, so cap it.
        top_k = min(5, logits.size(-1))
        _, indices = logits.topk(top_k, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1)) & mask.unsqueeze(-1)

        return correct.sum() / valid_count

    def training_step(self, batch, batch_idx):
        """Lightning hook: one training step."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: one validation step."""
        return self._step(batch, batch_idx, "val")
fish_speech/models/text2semantic/llama.py CHANGED
@@ -1,779 +1,887 @@
1
- import json
2
- import math
3
- from collections import OrderedDict
4
- from dataclasses import dataclass
5
- from pathlib import Path
6
- from typing import Optional
7
-
8
- import torch
9
- import torch.nn as nn
10
- from einops import rearrange
11
- from loguru import logger
12
- from torch import Tensor
13
- from torch.nn import functional as F
14
- from torch.nn.attention import SDPBackend, sdpa_kernel
15
- from torch.utils.checkpoint import checkpoint
16
- from transformers import AutoTokenizer
17
-
18
- from fish_speech.conversation import SEMANTIC_TOKEN
19
- from fish_speech.utils import RankedLogger
20
-
21
- from .lora import LoraConfig, setup_lora
22
-
23
- log = RankedLogger(__name__, rank_zero_only=True)
24
-
25
-
26
def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is greater than or equal to ``n``."""
    remainder = n % k
    # Already aligned: nothing to add. Otherwise top up to the next multiple.
    return n if remainder == 0 else n + k - remainder
30
-
31
-
32
- @dataclass
33
- class BaseModelArgs:
34
- model_type: str = "base"
35
-
36
- vocab_size: int = 32000
37
- n_layer: int = 32
38
- n_head: int = 32
39
- dim: int = 4096
40
- intermediate_size: int = None
41
- n_local_heads: int = -1
42
- head_dim: int = 64
43
- rope_base: float = 10000
44
- norm_eps: float = 1e-5
45
- max_seq_len: int = 2048
46
- dropout: float = 0.0
47
- tie_word_embeddings: bool = True
48
- attention_qkv_bias: bool = False
49
-
50
- # Codebook configs
51
- codebook_size: int = 160
52
- num_codebooks: int = 4
53
-
54
- # Gradient checkpointing
55
- use_gradient_checkpointing: bool = True
56
-
57
- # Initialize the model
58
- initializer_range: float = 0.02
59
-
60
- def __post_init__(self):
61
- if self.n_local_heads == -1:
62
- self.n_local_heads = self.n_head
63
- if self.intermediate_size is None:
64
- hidden_dim = 4 * self.dim
65
- n_hidden = int(2 * hidden_dim / 3)
66
- self.intermediate_size = find_multiple(n_hidden, 256)
67
- self.head_dim = self.dim // self.n_head
68
-
69
- @staticmethod
70
- def from_pretrained(path: str):
71
- path = Path(path)
72
-
73
- if path.is_dir():
74
- path = path / "config.json"
75
-
76
- with open(path, "r", encoding="utf-8") as f:
77
- data = json.load(f)
78
-
79
- match data["model_type"]:
80
- case "naive":
81
- cls = NaiveModelArgs
82
- case "dual_ar":
83
- cls = DualARModelArgs
84
- case _:
85
- raise ValueError(f"Unknown model type: {data['model_type']}")
86
-
87
- return cls(**data)
88
-
89
- def save(self, path: str):
90
- with open(path, "w") as f:
91
- json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
92
-
93
-
94
@dataclass
class NaiveModelArgs(BaseModelArgs):
    # Config variant selected when a config file's "model_type" is "naive";
    # it adds no fields of its own.
    model_type: str = "naive"
97
-
98
-
99
@dataclass
class DualARModelArgs(BaseModelArgs):
    # Config variant selected when a config file's "model_type" is "dual_ar".
    model_type: str = "dual_ar"
    # Number of transformer blocks in the "fast" stack built by DualARTransformer.
    n_fast_layer: int = 4
103
-
104
-
105
class KVCache(nn.Module):
    """Pre-allocated key / value buffers for incremental decoding."""

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        # Keys first, then values, both zero-initialised with the same layout.
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write new keys / values at ``input_pos`` and return the full buffers.

        ``input_pos`` has shape [S] and ``k_val`` has shape [B, H, S, D]; the
        returned tensors are the cache buffers themselves, not copies.
        """
        assert input_pos.shape[0] == k_val.shape[2]

        keys, values = self.k_cache, self.v_cache
        keys[:, :, input_pos] = k_val
        values[:, :, input_pos] = v_val
        return keys, values
124
-
125
-
126
@dataclass
class TransformerForwardResult:
    # Final output of the full models: logits of the text-token head ...
    token_logits: Tensor
    # ... and logits of the codebook head.
    codebook_logits: Tensor
130
-
131
-
132
@dataclass
class BaseTransformerForwardResult:
    # Intermediate output of BaseTransformer: logits of the text-token head ...
    logits: Tensor
    # ... and the hidden states the subclasses decode codebook logits from.
    hidden_states: Tensor
136
-
137
-
138
- class BaseTransformer(nn.Module):
139
- def __init__(
140
- self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
141
- ) -> None:
142
- super().__init__()
143
- self.config = config
144
- self.tokenizer = tokenizer
145
-
146
- self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
147
-
148
- # Slow transformer
149
- self.embeddings = nn.Embedding(
150
- config.vocab_size,
151
- config.dim,
152
- )
153
- self.codebook_embeddings = nn.Embedding(
154
- config.codebook_size * config.num_codebooks,
155
- config.dim,
156
- )
157
- self.layers = nn.ModuleList(
158
- TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
159
- )
160
- self.norm = RMSNorm(config.dim, eps=config.norm_eps)
161
-
162
- if self.config.tie_word_embeddings is False:
163
- self.output = nn.Linear(
164
- config.dim,
165
- config.vocab_size,
166
- bias=False,
167
- )
168
-
169
- self.register_buffer(
170
- "freqs_cis",
171
- precompute_freqs_cis(
172
- config.max_seq_len,
173
- config.dim // config.n_head,
174
- config.rope_base,
175
- ),
176
- persistent=False,
177
- )
178
- self.register_buffer(
179
- "causal_mask",
180
- torch.tril(
181
- torch.ones(
182
- config.max_seq_len,
183
- config.max_seq_len,
184
- dtype=torch.bool,
185
- )
186
- ),
187
- persistent=False,
188
- )
189
-
190
- # For kv cache
191
- self.max_batch_size = -1
192
- self.max_seq_len = -1
193
-
194
- if init_weights:
195
- self.apply(self._init_weights)
196
-
197
- def setup_caches(
198
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
199
- ):
200
- if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
201
- return
202
-
203
- head_dim = self.config.dim // self.config.n_head
204
- max_seq_len = find_multiple(max_seq_len, 8)
205
- self.max_seq_len = max_seq_len
206
- self.max_batch_size = max_batch_size
207
-
208
- for b in self.layers:
209
- b.attention.kv_cache = KVCache(
210
- max_batch_size,
211
- max_seq_len,
212
- self.config.n_local_heads,
213
- head_dim,
214
- dtype=dtype,
215
- )
216
-
217
- def embed(self, x: Tensor) -> Tensor:
218
- vocab_embeds = [self.embeddings(x[:, 0])]
219
- for i in range(self.config.num_codebooks):
220
- emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
221
- emb[x[:, 0] != self.semantic_token_id] = 0
222
- vocab_embeds.append(emb)
223
-
224
- x = torch.stack(vocab_embeds, dim=3)
225
- x = x.sum(dim=3)
226
-
227
- return x
228
-
229
- def forward(
230
- self,
231
- inp: Tensor,
232
- key_padding_mask: Optional[Tensor] = None,
233
- ) -> BaseTransformerForwardResult:
234
- seq_len = inp.size(2)
235
-
236
- # Here we want to merge the embeddings of the codebooks
237
- x = self.embed(inp)
238
-
239
- freqs_cis = self.freqs_cis[:seq_len]
240
-
241
- # Not that the causal mask here follows the definition of scaled_dot_product_attention
242
- # That is, FALSE means masked out
243
- # To maintain consistency, key_padding_mask use TRUE to mask out
244
- mask = None
245
- if key_padding_mask is not None:
246
- mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
247
- mask = mask & key_padding_mask[:, None, None, :].logical_not()
248
-
249
- for layer in self.layers:
250
- if self.config.use_gradient_checkpointing and self.training:
251
- x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
252
- else:
253
- x = layer(x, freqs_cis, mask)
254
-
255
- # We got slow_out here
256
- slow_out = self.norm(x)
257
-
258
- if self.config.tie_word_embeddings:
259
- token_logits = F.linear(slow_out, self.embeddings.weight)
260
- else:
261
- token_logits = self.output(slow_out)
262
-
263
- return BaseTransformerForwardResult(
264
- logits=token_logits,
265
- hidden_states=x,
266
- )
267
-
268
- def forward_generate(
269
- self,
270
- x: Tensor,
271
- input_pos: Optional[Tensor] = None,
272
- return_all: bool = False,
273
- ) -> BaseTransformerForwardResult:
274
- # This is used for generation, optimized for torch compile
275
- assert (
276
- self.max_seq_len != -1 and self.max_batch_size != -1
277
- ), "Please call setup_caches before forward_generate"
278
-
279
- x = self.embed(x)
280
-
281
- mask = self.causal_mask[
282
- None, None, input_pos, : self.max_seq_len
283
- ] # (B, N, Q, K)
284
- freqs_cis = self.freqs_cis[input_pos]
285
-
286
- for layer in self.layers:
287
- x = layer(x, freqs_cis, mask, input_pos=input_pos)
288
-
289
- # If prefill, we only calculate the logits of last token
290
- if x.size(1) > 1 and not return_all:
291
- x = x[:, -1:]
292
-
293
- # We got slow_out here
294
- slow_out = self.norm(x)
295
-
296
- if self.config.tie_word_embeddings:
297
- token_logits = F.linear(slow_out, self.embeddings.weight)
298
- else:
299
- token_logits = self.output(slow_out)
300
-
301
- return BaseTransformerForwardResult(
302
- logits=token_logits,
303
- hidden_states=x,
304
- )
305
-
306
- def _init_weights(self, module):
307
- std = self.config.initializer_range
308
- if isinstance(module, nn.Linear):
309
- module.weight.data.normal_(mean=0.0, std=std)
310
- if module.bias is not None:
311
- module.bias.data.zero_()
312
- elif isinstance(module, nn.Embedding):
313
- module.weight.data.normal_(mean=0.0, std=std)
314
- if module.padding_idx is not None:
315
- module.weight.data[module.padding_idx].zero_()
316
-
317
- @staticmethod
318
- def from_pretrained(
319
- path: str,
320
- load_weights: bool = False,
321
- max_length: int | None = None,
322
- lora_config: LoraConfig | None = None,
323
- rope_base: int | None = None,
324
- ) -> "BaseTransformer":
325
- config = BaseModelArgs.from_pretrained(str(path))
326
- if max_length is not None:
327
- config.max_seq_len = max_length
328
- log.info(f"Override max_seq_len to {max_length}")
329
-
330
- if rope_base is not None:
331
- config.rope_base = rope_base
332
- log.info(f"Override rope_base to {rope_base}")
333
-
334
- match config.model_type:
335
- case "naive":
336
- model_cls = NaiveTransformer
337
- case "dual_ar":
338
- model_cls = DualARTransformer
339
- case _:
340
- raise ValueError(f"Unknown model type: {config.model_type}")
341
-
342
- tokenizer = AutoTokenizer.from_pretrained(str(path))
343
- log.info(f"Loading model from {path}, config: {config}")
344
- model = model_cls(config, tokenizer=tokenizer)
345
-
346
- if lora_config is not None:
347
- setup_lora(model, lora_config)
348
- log.info(f"LoRA setup: {lora_config}")
349
-
350
- if load_weights is False:
351
- log.info("Randomly initialized model")
352
- else:
353
-
354
- if "int8" in str(Path(path)):
355
- logger.info("Using int8 weight-only quantization!")
356
- from tools.llama.quantize import WeightOnlyInt8QuantHandler
357
-
358
- simple_quantizer = WeightOnlyInt8QuantHandler(model)
359
- model = simple_quantizer.convert_for_runtime()
360
-
361
- if "int4" in str(Path(path)):
362
- logger.info("Using int4 quantization!")
363
- path_comps = path.name.split("-")
364
- assert path_comps[-2].startswith("g")
365
- groupsize = int(path_comps[-2][1:])
366
- from tools.llama.quantize import WeightOnlyInt4QuantHandler
367
-
368
- simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
369
- model = simple_quantizer.convert_for_runtime()
370
-
371
- weights = torch.load(
372
- Path(path) / "model.pth", map_location="cpu", mmap=True
373
- )
374
-
375
- if "state_dict" in weights:
376
- logger.warning(
377
- "Using a TextToSemantic LightningModule checkpoint, "
378
- "please make sure it is a full model, not a LoRA model."
379
- )
380
- weights = weights["state_dict"]
381
-
382
- if next(iter(weights.keys())).startswith("model."):
383
- logger.info(
384
- f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
385
- )
386
- new_weights = OrderedDict()
387
- for k, v in weights.items():
388
- new_weights[k.replace("model.", "")] = v
389
- weights = new_weights
390
-
391
- # Verify the name and shape of parameters since strict=False in load_state_dict.
392
- for k, v in model.named_parameters():
393
- if k not in weights:
394
- logger.warning(f"No weight for {k}")
395
- elif v.shape != weights[k].shape:
396
- logger.warning(
397
- f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
398
- )
399
-
400
- err = model.load_state_dict(weights, strict=False, assign=True)
401
- log.info(f"Loaded weights with error: {err}")
402
-
403
- return model
404
-
405
- def save_pretrained(self, path: str, drop_lora: bool = False):
406
- path = Path(path)
407
- path.mkdir(parents=True, exist_ok=True)
408
-
409
- self.config.save(path / "config.json")
410
- state_dict = self.state_dict()
411
-
412
- if drop_lora:
413
- for key in list(state_dict.keys()):
414
- if "lora" not in key:
415
- continue
416
-
417
- state_dict.pop(key)
418
- log.info(f"Drop LoRA parameter: {key}")
419
-
420
- torch.save(state_dict, path / "model.pth")
421
- self.tokenizer.save_pretrained(path)
422
-
423
-
424
class NaiveTransformer(BaseTransformer):
    """Base transformer plus a single linear head predicting all codebooks at once."""

    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
        # Initialisation is postponed so that one pass also covers the head
        # created below.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        head_width = config.codebook_size * config.num_codebooks
        self.codebook_output = nn.Linear(config.dim, head_width, bias=False)

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Turn the base result into token logits plus per-codebook logits."""
        normed_hidden = self.codebook_norm(result.hidden_states)
        flat_logits = self.codebook_output(normed_hidden)

        # Split the fused last axis into (codebook, entry).
        per_codebook = rearrange(
            flat_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=result.logits,
            codebook_logits=per_codebook,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Full-sequence forward pass followed by codebook decoding."""
        base_result = super().forward(inp=inp, key_padding_mask=key_padding_mask)
        return self.decode(base_result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Cached generation step followed by codebook decoding."""
        base_result = super().forward_generate(x, input_pos)
        return self.decode(base_result)
468
-
469
-
470
class DualARTransformer(BaseTransformer):
    """Base ("slow") transformer followed by a small "fast" transformer.

    The fast stack runs over the codebook axis: for every sequence position it
    receives the slow hidden state followed by codebook embeddings and emits
    one set of logits per codebook.
    """

    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        # Initialisation is deferred until the fast stack exists (see the
        # self.apply call at the end).
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)

        # The equivalent bs is so large that sdpa doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate the slow-stack caches, then a cache for every fast layer."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        # NOTE(review): these caches are rebuilt on every call, even when the
        # parent returned early because its caches were already large enough.
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Full-sequence forward pass producing token and codebook logits.

        Positions whose shifted codebook ids are all 0 are treated as padding:
        they are skipped by the fast stack and their codebook logits are 0.
        """
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last token and rotate left
        # (rows 1:-1 drop the text-token row and the last codebook row; the
        # first sequence position is dropped and one 0 is appended at the end)
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the slow hidden state as the first element of the codebook axis.
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        # Remember the unfiltered size so the logits can be re-padded below.
        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits (rows that were filtered out stay zero)
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """Run one cached step of the fast stack and return its codebook logits."""
        # Fast transformer
        # The input is reshaped to a single batch entry with a single position.
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits
597
-
598
-
599
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        # Submodules are created in this exact order (attention, feed-forward,
        # then the two norms) to keep parameter registration order stable.
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        """Apply normed attention and normed feed-forward, adding each to its input."""
        attended = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        residual = x + attended
        return residual + self.feed_forward(self.ffn_norm(residual))
613
-
614
-
615
- class Attention(nn.Module):
616
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
617
- super().__init__()
618
- assert config.dim % config.n_head == 0
619
-
620
- total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
621
- # key, query, value projections for all heads, but in a batch
622
- self.wqkv = nn.Linear(
623
- config.dim, total_head_dim, bias=config.attention_qkv_bias
624
- )
625
- self.wo = nn.Linear(config.dim, config.dim, bias=False)
626
- self.kv_cache = None
627
-
628
- self.dropout = config.dropout
629
- self.n_head = config.n_head
630
- self.head_dim = config.head_dim
631
- self.n_local_heads = config.n_local_heads
632
- self.dim = config.dim
633
- self.use_sdpa = use_sdpa
634
- self._register_load_state_dict_pre_hook(self.load_hook)
635
-
636
- def load_hook(self, state_dict, prefix, *args):
637
- if prefix + "wq.weight" in state_dict:
638
- wq = state_dict.pop(prefix + "wq.weight")
639
- wk = state_dict.pop(prefix + "wk.weight")
640
- wv = state_dict.pop(prefix + "wv.weight")
641
- state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
642
-
643
- def forward(
644
- self,
645
- x: Tensor,
646
- freqs_cis: Tensor,
647
- mask: Tensor,
648
- input_pos: Optional[Tensor] = None,
649
- ) -> Tensor:
650
- bsz, seqlen, _ = x.shape
651
-
652
- kv_size = self.n_local_heads * self.head_dim
653
- q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
654
-
655
- q = q.view(bsz, seqlen, self.n_head, self.head_dim)
656
- k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
657
- v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
658
-
659
- q = apply_rotary_emb(q, freqs_cis)
660
- k = apply_rotary_emb(k, freqs_cis)
661
-
662
- q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
663
-
664
- if self.kv_cache is not None:
665
- k, v = self.kv_cache.update(input_pos, k, v)
666
-
667
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
668
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
669
-
670
- if self.use_sdpa:
671
- if mask is None:
672
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
673
- y = F.scaled_dot_product_attention(
674
- q,
675
- k,
676
- v,
677
- dropout_p=self.dropout if self.training else 0.0,
678
- is_causal=True,
679
- # No third party attn_mask here to use flash_attention
680
- )
681
- else:
682
- y = F.scaled_dot_product_attention(
683
- q,
684
- k,
685
- v,
686
- attn_mask=mask,
687
- dropout_p=self.dropout if self.training else 0.0,
688
- )
689
- else:
690
- y = self.eq_scaled_dot_product_attention(
691
- q,
692
- k,
693
- v,
694
- attn_mask=mask,
695
- dropout_p=self.dropout if self.training else 0.0,
696
- )
697
-
698
- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
699
-
700
- return self.wo(y)
701
-
702
- def eq_scaled_dot_product_attention(
703
- self,
704
- query,
705
- key,
706
- value,
707
- attn_mask=None,
708
- dropout_p=0.0,
709
- ) -> torch.Tensor:
710
- # This is a standard scaled dot product attention
711
- # It's low efficient, but it doesn't raise cuda error
712
-
713
- L, S = query.size(-2), key.size(-2)
714
- scale_factor = 1 / math.sqrt(query.size(-1))
715
- attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
716
-
717
- if attn_mask is not None:
718
- if attn_mask.dtype == torch.bool:
719
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
720
- else:
721
- attn_bias += attn_mask
722
-
723
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
724
- attn_weight += attn_bias
725
- attn_weight = torch.softmax(attn_weight, dim=-1)
726
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
727
-
728
- return attn_weight @ value
729
-
730
-
731
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with bias-free linears."""

    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        width, inner = config.dim, config.intermediate_size
        # Creation order (w1, w3, w2) is kept so state_dict ordering is unchanged.
        self.w1 = nn.Linear(width, inner, bias=False)
        self.w3 = nn.Linear(width, inner, bias=False)
        self.w2 = nn.Linear(inner, width, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Project up through the SiLU gate and the linear branch, then back down."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
740
-
741
-
742
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added under the root)."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize in float32, cast back to the input dtype, then apply the scale."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
754
-
755
-
756
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Build the rotary-embedding table for positions ``0 .. seq_len - 1``.

    Returns a bfloat16 tensor of shape (seq_len, n_elem // 2, 2) whose last
    axis holds the (cos, sin) pair of each position/frequency angle.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base**exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    unit = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit.real, unit.imag], dim=-1)
    return table.to(dtype=torch.bfloat16)
765
-
766
-
767
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the angles stored in ``freqs_cis``.

    ``x`` is indexed as (batch, seq, heads, head_dim); ``freqs_cis`` is viewed
    as (1, seq, 1, head_dim // 2, 2) with (cos, sin) on the last axis. The
    rotation is computed in float32 and cast back to the dtype of ``x``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    real, imag = pairs[..., 0], pairs[..., 1]
    cos, sin = rot[..., 0], rot[..., 1]

    rotated = torch.stack([real * cos - imag * sin, imag * cos + real * sin], -1)
    return rotated.flatten(3).type_as(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import json
3
+ import math
4
+ from collections import OrderedDict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from loguru import logger
13
+ from torch import Tensor
14
+ from torch.nn import functional as F
15
+ from torch.nn.attention import SDPBackend, sdpa_kernel
16
+ from torch.utils.checkpoint import checkpoint
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
20
+ from fish_speech.utils import RankedLogger
21
+
22
+ from .lora import LoraConfig, setup_lora
23
+
24
+ log = RankedLogger(__name__, rank_zero_only=True)
25
+
26
+
27
def find_multiple(n: int, k: int) -> int:
    """Return ``n`` rounded up to the nearest multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
31
+
32
+
33
+ @dataclass
34
+ class BaseModelArgs:
35
+ model_type: str = "base"
36
+
37
+ vocab_size: int = 32000
38
+ n_layer: int = 32
39
+ n_head: int = 32
40
+ dim: int = 4096
41
+ intermediate_size: int = None
42
+ n_local_heads: int = -1
43
+ head_dim: int = 64
44
+ rope_base: float = 10000
45
+ norm_eps: float = 1e-5
46
+ max_seq_len: int = 2048
47
+ dropout: float = 0.0
48
+ tie_word_embeddings: bool = True
49
+ attention_qkv_bias: bool = False
50
+
51
+ # Codebook configs
52
+ codebook_size: int = 160
53
+ num_codebooks: int = 4
54
+
55
+ # Gradient checkpointing
56
+ use_gradient_checkpointing: bool = True
57
+
58
+ # Initialize the model
59
+ initializer_range: float = 0.02
60
+
61
+ # Dummy vars
62
+ is_reward_model: bool = False
63
+ share_codebook_embeddings: bool = True
64
+ scale_codebook_embeddings: bool = False
65
+
66
+ def __post_init__(self):
67
+ if self.n_local_heads == -1:
68
+ self.n_local_heads = self.n_head
69
+ if self.intermediate_size is None:
70
+ hidden_dim = 4 * self.dim
71
+ n_hidden = int(2 * hidden_dim / 3)
72
+ self.intermediate_size = find_multiple(n_hidden, 256)
73
+ self.head_dim = self.dim // self.n_head
74
+
75
+ @staticmethod
76
+ def from_pretrained(path: str):
77
+ path = Path(path)
78
+
79
+ if path.is_dir():
80
+ path = path / "config.json"
81
+
82
+ with open(path, "r", encoding="utf-8") as f:
83
+ data = json.load(f)
84
+
85
+ match data["model_type"]:
86
+ case "naive":
87
+ cls = NaiveModelArgs
88
+ case "dual_ar":
89
+ cls = DualARModelArgs
90
+ case _:
91
+ raise ValueError(f"Unknown model type: {data['model_type']}")
92
+
93
+ return cls(**data)
94
+
95
+ def save(self, path: str):
96
+ with open(path, "w") as f:
97
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
98
+
99
+
100
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for the ``"naive"`` model type; only the ``model_type`` default differs."""

    model_type: str = "naive"
103
+
104
+
105
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for the ``"dual_ar"`` model type, adding the fast-transformer sizes.

    Any ``fast_*`` size left unset (or falsy) inherits the matching value of
    the base config after the base ``__post_init__`` has run.
    """

    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        inherited = (
            ("fast_dim", "dim"),
            ("fast_n_head", "n_head"),
            ("fast_n_local_heads", "n_local_heads"),
            ("fast_head_dim", "head_dim"),
            ("fast_intermediate_size", "intermediate_size"),
        )
        for fast_name, base_name in inherited:
            setattr(
                self, fast_name, getattr(self, fast_name) or getattr(self, base_name)
            )

        # The bias flag is a bool, so only None (not False) means "inherit".
        if self.fast_attention_qkv_bias is None:
            self.fast_attention_qkv_bias = self.attention_qkv_bias
131
+
132
+
133
class KVCache(nn.Module):
    """Preallocated key/value buffers of shape (batch, heads, max_seq_len, head_dim)."""

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val``/``v_val`` at ``input_pos`` and return the full caches."""
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val

        return self.k_cache, self.v_cache
152
+
153
+
154
@dataclass
class TransformerForwardResult:
    """Pair of token logits and codebook logits returned by the model heads."""

    token_logits: Tensor
    codebook_logits: Tensor
158
+
159
+
160
@dataclass
class BaseTransformerForwardResult:
    """Token logits plus the hidden states they were computed from."""

    logits: Tensor
    hidden_states: Tensor
164
+
165
+
166
+ class BaseTransformer(nn.Module):
167
+ def __init__(
168
+ self,
169
+ config: BaseModelArgs,
170
+ tokenizer: FishTokenizer | AutoTokenizer,
171
+ init_weights: bool = True,
172
+ ) -> None:
173
+ super().__init__()
174
+ self.config = config
175
+ self.tokenizer = tokenizer
176
+ self.semantic_token_ids = [
177
+ tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
178
+ ]
179
+
180
+ # Slow transformer
181
+ self.embeddings = nn.Embedding(
182
+ config.vocab_size,
183
+ config.dim,
184
+ )
185
+ self.codebook_embeddings = nn.Embedding(
186
+ config.codebook_size * config.num_codebooks,
187
+ config.dim,
188
+ )
189
+ self.layers = nn.ModuleList(
190
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
191
+ )
192
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
193
+
194
+ if self.config.tie_word_embeddings is False:
195
+ self.output = nn.Linear(
196
+ config.dim,
197
+ config.vocab_size,
198
+ bias=False,
199
+ )
200
+
201
+ self.register_buffer(
202
+ "freqs_cis",
203
+ precompute_freqs_cis(
204
+ config.max_seq_len,
205
+ config.dim // config.n_head,
206
+ config.rope_base,
207
+ ),
208
+ persistent=False,
209
+ )
210
+ self.register_buffer(
211
+ "causal_mask",
212
+ torch.tril(
213
+ torch.ones(
214
+ config.max_seq_len,
215
+ config.max_seq_len,
216
+ dtype=torch.bool,
217
+ )
218
+ ),
219
+ persistent=False,
220
+ )
221
+
222
+ # For kv cache
223
+ self.max_batch_size = -1
224
+ self.max_seq_len = -1
225
+
226
+ if init_weights:
227
+ self.apply(self._init_weights)
228
+
229
+ def setup_caches(
230
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
231
+ ):
232
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
233
+ return
234
+
235
+ head_dim = self.config.dim // self.config.n_head
236
+ max_seq_len = find_multiple(max_seq_len, 8)
237
+ self.max_seq_len = max_seq_len
238
+ self.max_batch_size = max_batch_size
239
+
240
+ for b in self.layers:
241
+ b.attention.kv_cache = KVCache(
242
+ max_batch_size,
243
+ max_seq_len,
244
+ self.config.n_local_heads,
245
+ head_dim,
246
+ dtype=dtype,
247
+ )
248
+
249
+ def embed(self, x: Tensor) -> Tensor:
250
+ vocab_embeds = [self.embeddings(x[:, 0])]
251
+ for i in range(self.config.num_codebooks):
252
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
253
+ semantic_token_ids_tensor = torch.tensor(
254
+ self.semantic_token_ids, device=x.device
255
+ )
256
+ emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
257
+
258
+ x = torch.stack(vocab_embeds, dim=3)
259
+ x = x.sum(dim=3)
260
+
261
+ return x
262
+
263
+ def forward(
264
+ self,
265
+ inp: Tensor,
266
+ key_padding_mask: Optional[Tensor] = None,
267
+ ) -> BaseTransformerForwardResult:
268
+ seq_len = inp.size(2)
269
+
270
+ # Here we want to merge the embeddings of the codebooks
271
+ x = self.embed(inp)
272
+
273
+ freqs_cis = self.freqs_cis[:seq_len]
274
+
275
+ # Not that the causal mask here follows the definition of scaled_dot_product_attention
276
+ # That is, FALSE means masked out
277
+ # To maintain consistency, key_padding_mask use TRUE to mask out
278
+ mask = None
279
+ if key_padding_mask is not None:
280
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
281
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
282
+
283
+ for layer in self.layers:
284
+ if self.config.use_gradient_checkpointing and self.training:
285
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
286
+ else:
287
+ x = layer(x, freqs_cis, mask)
288
+
289
+ # We got slow_out here
290
+ slow_out = self.norm(x)
291
+
292
+ if self.config.tie_word_embeddings:
293
+ token_logits = F.linear(slow_out, self.embeddings.weight)
294
+ else:
295
+ token_logits = self.output(slow_out)
296
+
297
+ return BaseTransformerForwardResult(
298
+ logits=token_logits,
299
+ hidden_states=x,
300
+ )
301
+
302
+ def forward_generate(
303
+ self,
304
+ inp: Tensor,
305
+ input_pos: Optional[Tensor] = None,
306
+ vq_masks: Optional[Tensor] = None, # this is not used in fact
307
+ return_all: bool = False,
308
+ ) -> BaseTransformerForwardResult:
309
+ # This is used for generation, optimized for torch compile
310
+ # assert (
311
+ # self.max_seq_len != -1 and self.max_batch_size != -1
312
+ # ), "Please call setup_caches before forward_generate"
313
+
314
+ embeds = []
315
+ for i in range(self.config.num_codebooks):
316
+ if self.config.share_codebook_embeddings:
317
+ _tokens = inp[:, i + 1] + i * self.config.codebook_size
318
+ else:
319
+ _tokens = inp[:, i + 1]
320
+
321
+ emb = self.codebook_embeddings(_tokens)
322
+ embeds.append(emb)
323
+
324
+ vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
325
+ # if self.config.use_codebook_mlp:
326
+ # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
327
+ # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
328
+
329
+ vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
330
+ inp[:, 0] <= self.tokenizer.semantic_end_id
331
+ )
332
+
333
+ vq_embeds_sum[~vq_masks] = 0
334
+ x = self.embeddings(inp[:, 0]) + vq_embeds_sum
335
+
336
+ if input_pos is None:
337
+ input_pos = torch.arange(inp.shape[-1], device=x.device)
338
+ max_seq_len = inp.shape[-1]
339
+ else:
340
+ max_seq_len = self.max_seq_len
341
+
342
+ mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
343
+ freqs_cis = self.freqs_cis[input_pos]
344
+
345
+ for layer in self.layers:
346
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
347
+
348
+ # If prefill, we only calculate the logits of last token
349
+ if x.size(1) > 1 and not return_all:
350
+ x = x[:, -1:]
351
+
352
+ # We got slow_out here
353
+ slow_out = self.norm(x)
354
+
355
+ if self.config.is_reward_model:
356
+ token_logits = self.score_output(slow_out)
357
+ elif self.config.tie_word_embeddings:
358
+ token_logits = F.linear(slow_out, self.embeddings.weight)
359
+ else:
360
+ token_logits = self.output(slow_out)
361
+
362
+ return BaseTransformerForwardResult(
363
+ logits=token_logits,
364
+ hidden_states=x,
365
+ )
366
+
367
+ def _init_weights(self, module):
368
+ std = self.config.initializer_range
369
+ if isinstance(module, nn.Linear):
370
+ module.weight.data.normal_(mean=0.0, std=std)
371
+ if module.bias is not None:
372
+ module.bias.data.zero_()
373
+ elif isinstance(module, nn.Embedding):
374
+ module.weight.data.normal_(mean=0.0, std=std)
375
+ if module.padding_idx is not None:
376
+ module.weight.data[module.padding_idx].zero_()
377
+
378
+ @staticmethod
379
+ def from_pretrained(
380
+ path: str,
381
+ load_weights: bool = False,
382
+ max_length: int | None = None,
383
+ lora_config: LoraConfig | None = None,
384
+ rope_base: int | None = None,
385
+ is_agent: bool = False,
386
+ ) -> "BaseTransformer":
387
+ config = BaseModelArgs.from_pretrained(str(path))
388
+ if max_length is not None:
389
+ config.max_seq_len = max_length
390
+ log.info(f"Override max_seq_len to {max_length}")
391
+
392
+ if rope_base is not None:
393
+ config.rope_base = rope_base
394
+ log.info(f"Override rope_base to {rope_base}")
395
+
396
+ match config.model_type:
397
+ case "naive":
398
+ model_cls = NaiveTransformer
399
+ case "dual_ar":
400
+ model_cls = DualARTransformer
401
+ case _:
402
+ raise ValueError(f"Unknown model type: {config.model_type}")
403
+
404
+ if is_agent:
405
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
406
+ else:
407
+ tokenizer_path = str(path) + "/tokenizer.tiktoken"
408
+ tokenizer = FishTokenizer(tokenizer_path)
409
+
410
+ log.info(f"Loading model from {path}, config: {config}")
411
+ model = model_cls(config, tokenizer=tokenizer)
412
+
413
+ if lora_config is not None:
414
+ setup_lora(model, lora_config)
415
+ log.info(f"LoRA setup: {lora_config}")
416
+
417
+ if load_weights is False:
418
+ log.info("Randomly initialized model")
419
+ else:
420
+
421
+ if "int8" in str(Path(path)):
422
+ logger.info("Using int8 weight-only quantization!")
423
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
424
+
425
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
426
+ model = simple_quantizer.convert_for_runtime()
427
+
428
+ if "int4" in str(Path(path)):
429
+ logger.info("Using int4 quantization!")
430
+ path_comps = path.name.split("-")
431
+ assert path_comps[-2].startswith("g")
432
+ groupsize = int(path_comps[-2][1:])
433
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
434
+
435
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
436
+ model = simple_quantizer.convert_for_runtime()
437
+
438
+ weights = torch.load(
439
+ Path(path) / "model.pth",
440
+ map_location="cpu",
441
+ mmap=True,
442
+ weights_only=True,
443
+ )
444
+
445
+ if "state_dict" in weights:
446
+ logger.warning(
447
+ "Using a TextToSemantic LightningModule checkpoint, "
448
+ "please make sure it is a full model, not a LoRA model."
449
+ )
450
+ weights = weights["state_dict"]
451
+
452
+ if next(iter(weights.keys())).startswith("model."):
453
+ logger.info(
454
+ f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
455
+ )
456
+ new_weights = OrderedDict()
457
+ for k, v in weights.items():
458
+ new_weights[k.replace("model.", "")] = v
459
+ weights = new_weights
460
+
461
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
462
+ for k, v in model.named_parameters():
463
+ if k not in weights:
464
+ logger.warning(f"No weight for {k}")
465
+ elif v.shape != weights[k].shape:
466
+ logger.warning(
467
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
468
+ )
469
+
470
+ err = model.load_state_dict(weights, strict=False, assign=True)
471
+ log.info(f"Loaded weights with error: {err}")
472
+
473
+ return model
474
+
475
+ def save_pretrained(self, path: str, drop_lora: bool = False):
476
+ path = Path(path)
477
+ path.mkdir(parents=True, exist_ok=True)
478
+
479
+ self.config.save(path / "config.json")
480
+ state_dict = self.state_dict()
481
+
482
+ if drop_lora:
483
+ for key in list(state_dict.keys()):
484
+ if "lora" not in key:
485
+ continue
486
+
487
+ state_dict.pop(key)
488
+ log.info(f"Drop LoRA parameter: {key}")
489
+
490
+ torch.save(state_dict, path / "model.pth")
491
+ self.tokenizer.save_pretrained(path)
492
+
493
+
494
class NaiveTransformer(BaseTransformer):
    """Base transformer with one linear head predicting all codebooks at once."""

    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        # Initialization is deferred until the extra head below exists.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Turn hidden states into per-codebook logits, keeping the token logits."""
        # Codebook
        flat_logits = self.codebook_output(self.codebook_norm(result.hidden_states))
        per_codebook = rearrange(
            flat_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=result.logits,
            codebook_logits=per_codebook,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the base forward pass and decode its result."""
        base_result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(base_result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Run the base generation step and decode its result."""
        base_result = super().forward_generate(x, input_pos)
        return self.decode(base_result)
538
+
539
+
540
class DualARTransformer(BaseTransformer):
    """Base (slow) transformer followed by a small fast transformer over codebooks.

    The fast transformer attends over a sequence of length ``num_codebooks``
    whose first element is the slow hidden state for that position.
    """

    def __init__(self, config: DualARModelArgs, tokenizer: FishTokenizer) -> None:
        # Initialization is deferred until the fast modules below exist.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_dim // config.fast_n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Set up the slow-layer caches, then a ``num_codebooks``-long cache per fast layer."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.fast_dim // self.config.fast_n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Full-sequence forward pass producing token and per-codebook logits.

        Positions whose codebook inputs are all zero are treated as padding:
        they skip the fast transformer and receive zero codebook logits.
        """
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """Run one fast-transformer step and return its codebook logits.

        The input is reshaped to (1, 1, -1), i.e. one sequence and one position.
        """
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        """Run the slow generation step and project its hidden states to the fast width.

        Returns the base result object with ``hidden_states`` replaced in place.
        """
        x = super().forward_generate(x, input_pos, vq_masks)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x
705
+
706
+
707
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        # Registration order is kept so state_dict ordering is unchanged.
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        """Apply normalized attention, then the normalized feed-forward, adding residuals."""
        attended = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        hidden = x + attended
        return hidden + self.feed_forward(self.ffn_norm(hidden))
721
+
722
+
723
+ class Attention(nn.Module):
724
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
725
+ super().__init__()
726
+ assert config.dim % config.n_head == 0
727
+
728
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
729
+ # key, query, value projections for all heads, but in a batch
730
+ self.wqkv = nn.Linear(
731
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
732
+ )
733
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
734
+ self.kv_cache = None
735
+
736
+ self.dropout = config.dropout
737
+ self.n_head = config.n_head
738
+ self.head_dim = config.head_dim
739
+ self.n_local_heads = config.n_local_heads
740
+ self.dim = config.dim
741
+ self.use_sdpa = use_sdpa
742
+ self._register_load_state_dict_pre_hook(self.load_hook)
743
+
744
+ def load_hook(self, state_dict, prefix, *args):
745
+ if prefix + "wq.weight" in state_dict:
746
+ wq = state_dict.pop(prefix + "wq.weight")
747
+ wk = state_dict.pop(prefix + "wk.weight")
748
+ wv = state_dict.pop(prefix + "wv.weight")
749
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
750
+
751
+ def forward(
752
+ self,
753
+ x: Tensor,
754
+ freqs_cis: Tensor,
755
+ mask: Tensor,
756
+ input_pos: Optional[Tensor] = None,
757
+ ) -> Tensor:
758
+ bsz, seqlen, _ = x.shape
759
+
760
+ kv_size = self.n_local_heads * self.head_dim
761
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
762
+
763
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
764
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
765
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
766
+
767
+ q = apply_rotary_emb(q, freqs_cis)
768
+ k = apply_rotary_emb(k, freqs_cis)
769
+
770
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
771
+
772
+ if self.kv_cache is not None:
773
+ k, v = self.kv_cache.update(input_pos, k, v)
774
+
775
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
776
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
777
+
778
+ if self.use_sdpa:
779
+ if mask is None:
780
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
781
+ y = F.scaled_dot_product_attention(
782
+ q,
783
+ k,
784
+ v,
785
+ dropout_p=self.dropout if self.training else 0.0,
786
+ is_causal=True,
787
+ # No third party attn_mask here to use flash_attention
788
+ )
789
+ else:
790
+ y = F.scaled_dot_product_attention(
791
+ q,
792
+ k,
793
+ v,
794
+ attn_mask=mask,
795
+ dropout_p=self.dropout if self.training else 0.0,
796
+ )
797
+ else:
798
+ y = self.eq_scaled_dot_product_attention(
799
+ q,
800
+ k,
801
+ v,
802
+ attn_mask=mask,
803
+ dropout_p=self.dropout if self.training else 0.0,
804
+ )
805
+
806
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
807
+
808
+ return self.wo(y)
809
+
810
+ def eq_scaled_dot_product_attention(
811
+ self,
812
+ query,
813
+ key,
814
+ value,
815
+ attn_mask=None,
816
+ dropout_p=0.0,
817
+ ) -> torch.Tensor:
818
+ # This is a standard scaled dot product attention
819
+ # It's low efficient, but it doesn't raise cuda error
820
+
821
+ L, S = query.size(-2), key.size(-2)
822
+ scale_factor = 1 / math.sqrt(query.size(-1))
823
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
824
+
825
+ if attn_mask is not None:
826
+ if attn_mask.dtype == torch.bool:
827
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
828
+ else:
829
+ attn_bias += attn_mask
830
+
831
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
832
+ attn_weight += attn_bias
833
+ attn_weight = torch.softmax(attn_weight, dim=-1)
834
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
835
+
836
+ return attn_weight @ value
837
+
838
+
839
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with bias-free linears."""

    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        width, inner = config.dim, config.intermediate_size
        # Creation order (w1, w3, w2) is kept so state_dict ordering is unchanged.
        self.w1 = nn.Linear(width, inner, bias=False)
        self.w3 = nn.Linear(width, inner, bias=False)
        self.w2 = nn.Linear(inner, width, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Project up through the SiLU gate and the linear branch, then back down."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
848
+
849
+
850
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added under the root)."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize in float32, cast back to the input dtype, then apply the scale."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
862
+
863
+
864
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Build the rotary-embedding table for positions ``0 .. seq_len - 1``.

    Returns a bfloat16 tensor of shape (seq_len, n_elem // 2, 2) whose last
    axis holds the (cos, sin) pair of each position/frequency angle.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base**exponents)
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    unit = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit.real, unit.imag], dim=-1)
    return table.to(dtype=torch.bfloat16)
873
+
874
+
875
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the angles stored in ``freqs_cis``.

    ``x`` is indexed as (batch, seq, heads, head_dim); ``freqs_cis`` is viewed
    as (1, seq, 1, head_dim // 2, 2) with (cos, sin) on the last axis. The
    rotation is computed in float32 and cast back to the dtype of ``x``.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    real, imag = pairs[..., 0], pairs[..., 1]
    cos, sin = rot[..., 0], rot[..., 1]

    rotated = torch.stack([real * cos - imag * sin, imag * cos + real * sin], -1)
    return rotated.flatten(3).type_as(x)
fish_speech/models/text2semantic/lora.py CHANGED
@@ -1,92 +1,92 @@
1
- from dataclasses import dataclass
2
-
3
- import loralib as lora
4
-
5
-
6
@dataclass
class LoraConfig:
    """Hyper-parameters consumed by ``setup_lora``."""

    # Rank passed as ``r`` to every LoRA layer.
    r: int
    # Scaling numerator passed as ``lora_alpha`` to every LoRA layer.
    lora_alpha: float
    # Dropout passed to the LoRA linear layers only.
    lora_dropout: float = 0.0
11
-
12
-
13
def setup_lora(model, lora_config):
    """Swap the model's embeddings and linear layers for LoRA layers, in place.

    The replacement layers are built from the shapes of the existing ones only;
    no weights are copied here. After the swap, only LoRA parameters are left
    trainable.

    Args:
        model: Module exposing ``embeddings``, ``codebook_embeddings``,
            ``output`` and ``layers``; when it also has ``fast_layers`` the
            ``fast_embeddings``, ``fast_output`` and fast layers are swapped too.
        lora_config: Object with ``r``, ``lora_alpha`` and ``lora_dropout``.
    """

    def _lora_embedding(embedding):
        # LoRA embedding with the same vocabulary, width and padding index.
        return lora.Embedding(
            num_embeddings=embedding.num_embeddings,
            embedding_dim=embedding.embedding_dim,
            padding_idx=embedding.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

    def _block_linears(blocks):
        # (owner, attribute) pairs for every linear layer in the given blocks.
        targets = []
        for block in blocks:
            targets.extend([(block.attention, "wqkv"), (block.attention, "wo")])
            targets.extend(
                [
                    (block.feed_forward, "w1"),
                    (block.feed_forward, "w2"),
                    (block.feed_forward, "w3"),
                ]
            )
        return targets

    # Replace the embedding layers with LoRA layers
    model.embeddings = _lora_embedding(model.embeddings)
    model.codebook_embeddings = _lora_embedding(model.codebook_embeddings)

    # Output layer first, then every linear layer of the main stack
    linears = [(model, "output")]
    linears.extend(_block_linears(model.layers))

    if hasattr(model, "fast_layers"):
        # Dual-AR model
        model.fast_embeddings = _lora_embedding(model.fast_embeddings)
        linears.append((model, "fast_output"))
        linears.extend(_block_linears(model.fast_layers))

    for module, layer in linears:
        original = getattr(module, layer)
        updated_linear = lora.Linear(
            in_features=original.in_features,
            out_features=original.out_features,
            # ``bias`` is a boolean flag; passing the bias tensor itself fails
            # on truth-testing as soon as the layer actually has a bias.
            bias=original.bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Mark only the LoRA layers as trainable
    lora.mark_only_lora_as_trainable(model, bias="none")
80
-
81
-
82
def get_merged_state_dict(model):
    """Return the model's state dict with every ``lora`` entry removed.

    The model is switched to eval mode first; the original author relies on
    that call to merge the LoRA parameters into the base weights.
    """
    model.eval()

    merged = model.state_dict()
    # Collect the keys up front so the dict is not mutated while iterating.
    lora_keys = [key for key in merged.keys() if "lora" in key]
    for key in lora_keys:
        merged.pop(key)

    return merged
 
1
+ from dataclasses import dataclass
2
+
3
+ import loralib as lora
4
+
5
+
6
@dataclass
class LoraConfig:
    """Hyper-parameters consumed by ``setup_lora``."""

    # Rank passed as ``r`` to every LoRA layer.
    r: int
    # Scaling numerator passed as ``lora_alpha`` to every LoRA layer.
    lora_alpha: float
    # Dropout passed to the LoRA linear layers only.
    lora_dropout: float = 0.0
11
+
12
+
13
def setup_lora(model, lora_config):
    """Swap the model's embeddings and linear layers for LoRA layers, in place.

    The replacement layers are built from the shapes of the existing ones only;
    no weights are copied here. After the swap, only LoRA parameters are left
    trainable.

    Args:
        model: Module exposing ``embeddings``, ``codebook_embeddings``,
            ``output`` and ``layers``; when it also has ``fast_layers`` the
            ``fast_embeddings``, ``fast_output`` and fast layers are swapped too.
        lora_config: Object with ``r``, ``lora_alpha`` and ``lora_dropout``.
    """

    def _lora_embedding(embedding):
        # LoRA embedding with the same vocabulary, width and padding index.
        return lora.Embedding(
            num_embeddings=embedding.num_embeddings,
            embedding_dim=embedding.embedding_dim,
            padding_idx=embedding.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

    def _block_linears(blocks):
        # (owner, attribute) pairs for every linear layer in the given blocks.
        targets = []
        for block in blocks:
            targets.extend([(block.attention, "wqkv"), (block.attention, "wo")])
            targets.extend(
                [
                    (block.feed_forward, "w1"),
                    (block.feed_forward, "w2"),
                    (block.feed_forward, "w3"),
                ]
            )
        return targets

    # Replace the embedding layers with LoRA layers
    model.embeddings = _lora_embedding(model.embeddings)
    model.codebook_embeddings = _lora_embedding(model.codebook_embeddings)

    # Output layer first, then every linear layer of the main stack
    linears = [(model, "output")]
    linears.extend(_block_linears(model.layers))

    if hasattr(model, "fast_layers"):
        # Dual-AR model
        model.fast_embeddings = _lora_embedding(model.fast_embeddings)
        linears.append((model, "fast_output"))
        linears.extend(_block_linears(model.fast_layers))

    for module, layer in linears:
        original = getattr(module, layer)
        updated_linear = lora.Linear(
            in_features=original.in_features,
            out_features=original.out_features,
            # ``bias`` is a boolean flag; passing the bias tensor itself fails
            # on truth-testing as soon as the layer actually has a bias.
            bias=original.bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Mark only the LoRA layers as trainable
    lora.mark_only_lora_as_trainable(model, bias="none")
80
+
81
+
82
def get_merged_state_dict(model):
    """Return the model's state dict with every ``lora`` entry removed.

    The model is switched to eval mode first; the original author relies on
    that call to merge the LoRA parameters into the base weights.
    """
    model.eval()

    merged = model.state_dict()
    # Collect the keys up front so the dict is not mutated while iterating.
    lora_keys = [key for key in merged.keys() if "lora" in key]
    for key in lora_keys:
        merged.pop(key)

    return merged
fish_speech/models/vqgan/lit_module.py DELETED
@@ -1,442 +0,0 @@
1
- import itertools
2
- import math
3
- from typing import Any, Callable
4
-
5
- import lightning as L
6
- import torch
7
- import torch.nn.functional as F
8
- import wandb
9
- from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
10
- from matplotlib import pyplot as plt
11
- from torch import nn
12
-
13
- from fish_speech.models.vqgan.modules.discriminator import Discriminator
14
- from fish_speech.models.vqgan.modules.wavenet import WaveNet
15
- from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
16
-
17
-
18
class VQGAN(L.LightningModule):
    """Lightning module training a mel-domain VQ-GAN with manual optimisation.

    An encoder and quantizer turn mels into quantized features, a decoder
    regenerates mels from noise conditioned on those features, and a
    discriminator supplies a least-squares adversarial signal. The vocoder is
    frozen and only used to render audio for validation logging and ``decode``.
    """

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        encoder: WaveNet,
        quantizer: nn.Module,
        decoder: WaveNet,
        discriminator: Discriminator,
        vocoder: nn.Module,
        encode_mel_transform: nn.Module,
        gt_mel_transform: nn.Module,
        weight_adv: float = 1.0,
        weight_vq: float = 1.0,
        weight_mel: float = 1.0,
        sampling_rate: int = 44100,
        freeze_encoder: bool = False,
    ):
        super().__init__()

        # Factories: ``optimizer(params)`` and ``lr_scheduler(optimizer)``
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Modules
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.vocoder = vocoder
        self.discriminator = discriminator
        self.encode_mel_transform = encode_mel_transform
        self.gt_mel_transform = gt_mel_transform

        # A simple linear layer to project quality to condition channels
        self.quality_projection = nn.Linear(1, 768)

        # Freeze vocoder
        for param in self.vocoder.parameters():
            param.requires_grad = False

        # Loss weights
        self.weight_adv = weight_adv
        self.weight_vq = weight_vq
        self.weight_mel = weight_mel

        # Other parameters
        self.sampling_rate = sampling_rate

        # Disable strict loading (the vocoder is stripped from checkpoints)
        self.strict_loading = False

        # If encoder is frozen, the quantizer is frozen with it
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

            for param in self.quantizer.parameters():
                param.requires_grad = False

        # Generator and discriminator are stepped by hand in training_step
        self.automatic_optimization = False

    def on_save_checkpoint(self, checkpoint):
        """Drop every vocoder entry from the checkpoint's state dict."""
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "vocoder" in name:
                state_dict.pop(name)

    def configure_optimizers(self):
        """Create one optimizer/scheduler pair each for generator and discriminator."""
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.encoder.parameters(),
                self.quantizer.parameters(),
                self.decoder.parameters(),
                self.quality_projection.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
        """Run one discriminator update followed by one generator update."""
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            encoded_mels = self.encode_mel_transform(audios)
            gt_mels = self.gt_mel_transform(audios)
            # Scalar "quality" per sample derived from how many mel bins have a
            # time-averaged value above -8.
            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
            quality = quality.unsqueeze(-1)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize
        vq_result = self.quantizer(encoded_features)
        # Read the loss from the quantizer result object (not from a string
        # literal, which would always yield the 0.0 default).
        loss_vq = getattr(vq_result, "loss", 0.0)
        vq_recon_features = vq_result.z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features + self.quality_projection(quality)[:, :, None]
        )

        # VQ Decode: masked noise in, quantized features as condition
        gen_mel = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )

        # Discriminator (generated mel detached so only D gets gradients)
        real_logits = self.discriminator(gt_mels)
        fake_logits = self.discriminator(gen_mel.detach())
        d_mask = F.interpolate(
            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
        )

        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
        loss_fake = avg_with_mask(fake_logits**2, d_mask)

        loss_d = loss_real + loss_fake

        self.log(
            "train/discriminator/loss",
            loss_d,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        # Discriminator backward
        optim_d.zero_grad()
        self.manual_backward(loss_d)
        self.clip_gradients(
            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_d.step()

        # Mel Loss, applying l1, using a weighted sum over three frequency bands
        mel_distance = (gen_mel - gt_mels).abs()
        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
        loss_mel_mid_freq = avg_with_mask(
            mel_distance[:, 40:70, :], mel_masks_float_conv
        )
        loss_mel_high_freq = avg_with_mask(
            mel_distance[:, 70:, :], mel_masks_float_conv
        )
        loss_mel = (
            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
        )

        # Adversarial Loss
        fake_logits = self.discriminator(gen_mel)
        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)

        # Total loss
        loss = (
            self.weight_vq * loss_vq
            + self.weight_mel * loss_mel
            + self.weight_adv * loss_adv
        )

        # Log losses (only the total goes to the progress bar)
        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )
        for name, value in (
            ("loss_vq", loss_vq),
            ("loss_mel", loss_mel),
            ("loss_adv", loss_adv),
        ):
            self.log(
                f"train/generator/{name}",
                value,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
            )

        # Generator backward
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
        """Log the masked L1 mel loss and, for the first batch, sample mels/audio."""
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        encoded_mels = self.encode_mel_transform(audios)
        gt_mels = self.gt_mel_transform(audios)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize; a constant quality of 2 is used outside training
        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features
            + self.quality_projection(
                torch.ones(
                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
                )
                * 2
            )[:, :, None]
        )

        # VQ Decode
        gen_aux_mels = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )
        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)

        self.log(
            "val/loss_mel",
            loss_mel,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        recon_audios = self.vocoder(gt_mels)
        gen_aux_audios = self.vocoder(gen_aux_mels)

        # only log the first batch
        if batch_idx != 0:
            return

        for idx, (
            gt_mel,
            gen_aux_mel,
            audio,
            gen_aux_audio,
            recon_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                gen_aux_mels,
                audios.cpu().float(),
                gen_aux_audios.cpu().float(),
                recon_audios.cpu().float(),
                audio_lengths,
            )
        ):
            # At most five samples are logged
            if idx > 4:
                break

            mel_len = audio_len // self.gt_mel_transform.hop_length

            image_mels = plot_mel(
                [
                    gt_mel[:, :mel_len],
                    gen_aux_mel[:, :mel_len],
                ],
                [
                    "Ground-Truth",
                    "Auxiliary",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_aux_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="aux",
                            ),
                            wandb.Audio(
                                recon_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="recon",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                for tag, waveform in (
                    ("gt", audio),
                    ("gen", gen_aux_audio),
                    ("recon", recon_audio),
                ):
                    self.logger.experiment.add_audio(
                        f"sample-{idx}/wavs/{tag}",
                        waveform[0, :audio_len],
                        self.global_step,
                        sample_rate=self.sampling_rate,
                    )

            # Release the figure once it has been handed to the logger
            plt.close(image_mels)

    def encode(self, audios, audio_lengths):
        """Return ``(quantizer.encode(features), feature_lengths)`` for the audio."""
        audios = audios.float()

        mels = self.encode_mel_transform(audios)
        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths, return_audios=False):
        """Decode quantizer indices to a mel, or to audio when ``return_audios``."""
        factor = math.prod(self.quantizer.downsample_factor)
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        # Same constant quality of 2 as in validation
        z = (
            z
            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
                :, :, None
            ]
        )

        gen_mel = (
            self.decoder(
                torch.randn_like(z) * mel_masks_float_conv,
                condition=z,
            )
            * mel_masks_float_conv
        )

        if return_audios:
            return self.vocoder(gen_mel)

        return gen_mel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/vqgan/modules/discriminator.py DELETED
@@ -1,44 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn.utils.parametrizations import weight_norm
4
-
5
-
6
- class Discriminator(nn.Module):
7
- def __init__(self):
8
- super().__init__()
9
-
10
- blocks = []
11
- convs = [
12
- (1, 64, (3, 9), 1, (1, 4)),
13
- (64, 128, (3, 9), (1, 2), (1, 4)),
14
- (128, 256, (3, 9), (1, 2), (1, 4)),
15
- (256, 512, (3, 9), (1, 2), (1, 4)),
16
- (512, 1024, (3, 3), 1, (1, 1)),
17
- (1024, 1, (3, 3), 1, (1, 1)),
18
- ]
19
-
20
- for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
21
- convs
22
- ):
23
- blocks.append(
24
- weight_norm(
25
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
26
- )
27
- )
28
-
29
- if idx != len(convs) - 1:
30
- blocks.append(nn.SiLU(inplace=True))
31
-
32
- self.blocks = nn.Sequential(*blocks)
33
-
34
- def forward(self, x):
35
- return self.blocks(x[:, None])[:, 0]
36
-
37
-
38
# Smoke test: print the parameter count (in millions) and one forward pass.
if __name__ == "__main__":
    model = Discriminator()
    parameter_count = sum(p.numel() for p in model.parameters())
    print(parameter_count / 1_000_000)
    x = torch.randn(1, 128, 1024)
    y = model(x)
    print(y.shape)
    print(y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/vqgan/modules/firefly.py CHANGED
@@ -1,596 +1,596 @@
1
- import math
2
- from functools import partial
3
- from math import prod
4
- from typing import Callable
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- from torch import nn
9
- from torch.nn.utils.parametrizations import weight_norm
10
- from torch.nn.utils.parametrize import remove_parametrizations
11
- from torch.utils.checkpoint import checkpoint
12
-
13
-
14
def sequence_mask(length, max_length=None):
    """Return a boolean mask of shape ``(len(length), max_length)``.

    Entry ``[i, j]`` is True when ``j < length[i]``. ``max_length`` defaults to
    the largest value in ``length``.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
19
-
20
-
21
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when the class name contains "Conv1D".

    NOTE(review): the match is case-sensitive, so torch's ``Conv1d`` /
    ``ConvTranspose1d`` class names do not match "Conv1D" — confirm whether
    this initialiser is meant to be a no-op for them.
    """
    if "Conv1D" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
25
-
26
-
27
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) // 2``, the per-side padding."""
    dilated_span = kernel_size * dilation - dilation
    return dilated_span // 2
29
-
30
-
31
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Strip ``paddings = (left, right)`` samples from the last axis of ``x``.

    A right padding of zero is handled by slicing up to an explicit end index.
    Only for 1d!
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    total_length = x.shape[-1]
    assert (padding_left + padding_right) <= total_length
    return x[..., padding_left : total_length - padding_right]
38
-
39
-
40
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return the extra samples needed so the last conv frame is complete.

    The frame count for the current length is rounded up, and the difference
    between the length that count requires and the actual length is returned.
    See `pad_for_conv1d`.
    """
    length = x.shape[-1]
    frame_count = math.ceil((length - kernel_size + padding_total) / stride + 1)
    ideal_length = (frame_count - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
48
-
49
-
50
def pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Pad the last axis of ``x`` by ``paddings = (left, right)``.

    Tiny wrapper around ``F.pad`` that also allows reflect padding on inputs
    shorter than the requested padding: extra zeros are appended on the right
    before reflecting and removed again afterwards.

    Args:
        x: Tensor to pad.
        paddings: Non-negative ``(left, right)`` amounts.
        mode: ``F.pad`` mode. The default ``"zeros"`` is accepted as an alias
            for ``"constant"`` (``F.pad`` itself has no ``"zeros"`` mode).
        value: Fill value, forwarded to ``F.pad``.
    """
    if mode == "zeros":
        # "zeros" is not an F.pad mode; without this alias the default raises.
        mode = "constant"

    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)

    if mode != "reflect":
        return F.pad(x, paddings, mode, value)

    # Reflect padding needs the input to be longer than the padding.
    max_pad = max(padding_left, padding_right)
    extra_pad = 0
    if length <= max_pad:
        extra_pad = max_pad - length + 1
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    end = padded.shape[-1] - extra_pad
    return padded[..., :end]
74
-
75
-
76
- class FishConvNet(nn.Module):
77
- def __init__(
78
- self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
79
- ):
80
- super(FishConvNet, self).__init__()
81
- self.conv = nn.Conv1d(
82
- in_channels,
83
- out_channels,
84
- kernel_size,
85
- stride=stride,
86
- dilation=dilation,
87
- groups=groups,
88
- )
89
- self.stride = stride
90
- self.kernel_size = (kernel_size - 1) * dilation + 1
91
- self.dilation = dilation
92
-
93
- def forward(self, x):
94
- pad = self.kernel_size - self.stride
95
- extra_padding = get_extra_padding_for_conv1d(
96
- x, self.kernel_size, self.stride, pad
97
- )
98
- x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
99
- return self.conv(x).contiguous()
100
-
101
- def weight_norm(self, name="weight", dim=0):
102
- self.conv = weight_norm(self.conv, name=name, dim=dim)
103
- return self
104
-
105
- def remove_weight_norm(self):
106
- self.conv = remove_parametrizations(self.conv)
107
- return self
108
-
109
-
110
- class FishTransConvNet(nn.Module):
111
- def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
112
- super(FishTransConvNet, self).__init__()
113
- self.conv = nn.ConvTranspose1d(
114
- in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
115
- )
116
- self.stride = stride
117
- self.kernel_size = kernel_size
118
-
119
- def forward(self, x):
120
- x = self.conv(x)
121
- pad = self.kernel_size - self.stride
122
- padding_right = math.ceil(pad)
123
- padding_left = pad - padding_right
124
- x = unpad1d(x, (padding_left, padding_right))
125
- return x.contiguous()
126
-
127
- def weight_norm(self, name="weight", dim=0):
128
- self.conv = weight_norm(self.conv, name=name, dim=dim)
129
- return self
130
-
131
- def remove_weight_norm(self):
132
- self.conv = remove_parametrizations(self.conv)
133
- return self
134
-
135
-
136
class ResBlock1(torch.nn.Module):
    """Three residual stages of ``SiLU -> conv -> SiLU -> conv`` with a skip each.

    Both conv lists use the first three entries of ``dilation``, in order.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def _conv_stack():
            # Three weight-normalised convs, one per dilation rate.
            return nn.ModuleList(
                [
                    FishConvNet(
                        channels, channels, kernel_size, stride=1, dilation=dilation[i]
                    ).weight_norm()
                    for i in range(3)
                ]
            )

        self.convs1 = _conv_stack()
        self.convs1.apply(init_weights)

        self.convs2 = _conv_stack()
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Run the three residual stages on ``x``."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_parametrizations(self):
        """Strip weight normalisation from every convolution in the block.

        The parametrization lives on the wrapped ``.conv`` module, not on the
        ``FishConvNet`` wrapper, so removal has to target ``.conv``.
        """
        for conv in itertools.chain(self.convs1, self.convs2):
            remove_parametrizations(conv.conv, tensor_name="weight")
184
-
185
-
186
class ParallelBlock(nn.Module):
    """Run several ``ResBlock1`` branches on the same input and average them."""

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        # One branch per (kernel size, dilation tuple) pair.
        branches = [
            ResBlock1(channels, kernel, dilations)
            for kernel, dilations in zip(kernel_sizes, dilation_sizes)
        ]
        self.blocks = nn.ModuleList(branches)

    def forward(self, x):
        """Return the element-wise mean of all branch outputs."""
        branch_outputs = [block(x) for block in self.blocks]
        stacked = torch.stack(branch_outputs, dim=0)
        return stacked.mean(dim=0)

    def remove_parametrizations(self):
        """Delegate parametrization removal to every branch."""
        for block in self.blocks:
            block.remove_parametrizations()
207
-
208
-
209
class HiFiGANGenerator(nn.Module):
    """Upsampling generator: pre-conv, upsample + parallel res-block stages, post-conv.

    Each stage halves the channel count. The product of ``upsample_rates`` must
    equal ``hop_length``. The output passes through ``tanh``.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = FishConvNet(
            num_mels,
            upsample_initial_channel,
            pre_conv_kernel_size,
            stride=1,
        ).weight_norm()

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        # Gradient checkpointing of the res-blocks is off unless a caller sets
        # this flag; forward() reads it, so it must always exist.
        self.checkpointing = False

        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                FishTransConvNet(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    stride=u,
                ).weight_norm()
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = FishConvNet(
            ch, 1, post_conv_kernel_size, stride=1
        ).weight_norm()
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Map the input features to a ``tanh``-bounded single-channel signal."""
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.training and self.checkpointing:
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_parametrizations(self):
        """Strip weight normalisation from all convolutions.

        The parametrization is registered on each wrapper's inner ``.conv``
        module, so removal targets ``.conv`` rather than the wrapper.
        """
        for up in self.ups:
            remove_parametrizations(up.conv, tensor_name="weight")
        for block in self.resblocks:
            block.remove_parametrizations()
        remove_parametrizations(self.conv_pre.conv, tensor_name="weight")
        remove_parametrizations(self.conv_post.conv, tensor_name="weight")
296
-
297
-
298
# DropPath copied from timm library
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop whole samples at random (stochastic depth) in residual branches.

    Returns ``x`` itself when not training or when ``drop_prob`` is zero.
    Otherwise each sample along dim 0 is kept with probability
    ``1 - drop_prob`` and, if ``scale_by_keep``, the kept samples are divided
    by that probability.
    """  # noqa: E501

    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
322
-
323
-
324
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` that follows ``self.training``."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply stochastic depth to ``x`` using the stored settings."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability rounded to three decimals."""
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
337
-
338
-
339
class LayerNorm(nn.Module):
    r"""Layer normalisation for channels_last (default) or channels_first inputs.

    With ``channels_last`` the last dimension is normalised via
    ``F.layer_norm``. With ``channels_first`` dimension 1 is normalised by hand
    and the affine parameters are broadcast with one trailing axis.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalise ``x`` along the axis selected by ``data_format``."""
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            # Biased variance over the channel axis.
            variance = centered.pow(2).mean(1, keepdim=True)
            normalised = centered / torch.sqrt(variance + self.eps)
            return self.weight[:, None] * normalised + self.bias[:, None]
367
-
368
-
369
- # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
370
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block for 1-D sequences shaped ``(N, C, L)``.

    Pipeline: depthwise conv -> permute to ``(N, L, C)`` -> LayerNorm
    (channels_last) -> Linear -> GELU -> Linear -> optional layer scale ->
    permute back -> drop path -> optional residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            ``<= 0`` disables the learnable scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Accepted for interface compatibility. Default: 1.
            NOTE(review): it is not forwarded to the depthwise conv here.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        hidden_dim = int(mlp_ratio * dim)

        # Depthwise conv: one filter group per channel.
        self.dwconv = FishConvNet(dim, dim, kernel_size=kernel_size, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs are implemented as linear layers on (N, L, C).
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        shortcut = x

        y = self.dwconv(x).permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = self.drop_path(y.permute(0, 2, 1))  # (N, L, C) -> (N, C, L)

        if not apply_residual:
            return y
        return shortcut + y
436
-
437
-
438
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages over ``(N, C, L)`` inputs.

    Every stage is preceded by a projection layer: a ``FishConvNet`` stem
    (kernel 7) for the first stage and ``LayerNorm`` + 1x1 ``Conv1d`` for
    the others. All projections use stride 1.

    Args:
        input_channels: Channels of the input tensor.
        depths: Number of ``ConvNeXtBlock`` s in each stage.
        dims: Channel width of each stage; same length as ``depths``.
        drop_path_rate: Final drop-path rate; per-block rates rise linearly
            from 0 to this value across all blocks.
        layer_scale_init_value: Forwarded to every block.
        kernel_size: Depthwise kernel size forwarded to every block.
    """

    def __init__(
        self,
        input_channels: int = 3,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # Projection in front of stage 0, then one in front of each later stage.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(
            nn.Sequential(
                FishConvNet(input_channels, dims[0], kernel_size=7),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        )
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(dim_in, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(dim_in, dim_out, kernel_size=1),
                )
            )

        # One drop-path rate per block, consumed in construction order.
        rates = iter(torch.linspace(0, drop_path_rate, sum(depths)).tolist())

        self.stages = nn.ModuleList()
        for depth, dim in zip(depths, dims):
            blocks = [
                ConvNeXtBlock(
                    dim=dim,
                    drop_path=next(rates),
                    layer_scale_init_value=layer_scale_init_value,
                    kernel_size=kernel_size,
                )
                for _ in range(depth)
            ]
            self.stages.append(nn.Sequential(*blocks))

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights and zero bias for every conv / linear layer.
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for project, stage in zip(self.downsample_layers, self.stages):
            x = stage(project(x))
        return self.norm(x)
508
-
509
-
510
class FireflyArchitecture(nn.Module):
    """Wires a spectrogram transform, backbone, quantizer and head together.

    Args:
        backbone: Encoder applied to the (transformed) input.
        head: Decoder/vocoder applied to the quantized features. Called as
            ``head(x, template=...)`` in ``forward`` and ``head(z)`` in
            ``decode``.
        quantizer: Module whose call result exposes ``.z`` and which
            provides ``downsample_factor``, ``encode`` and ``decode``.
            May be ``None`` for ``forward`` only.
        spec_transform: Module producing features from audio; ``encode`` and
            ``decode`` also read its ``hop_length``. May be ``None`` for
            ``forward`` only.
    """

    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()

        self.backbone = backbone
        self.head = head
        self.quantizer = quantizer
        self.spec_transform = spec_transform
        # forward() tolerates a missing quantizer, so construction must too.
        self.downsample_factor = (
            math.prod(quantizer.downsample_factor) if quantizer is not None else 1
        )

    def forward(self, x: torch.Tensor, template=None, mask=None):
        """Run transform -> backbone -> quantizer -> head.

        Args:
            x: Input passed to ``spec_transform`` (or directly to the
                backbone when no transform is configured).
            template: Forwarded to the head as a keyword argument.
            mask: Optional tensor multiplied into the features after the
                backbone and again after quantization.

        Returns:
            ``(output, vq_result)`` when a quantizer is configured, otherwise
            just ``output``. A 2-D head output gets a channel axis inserted.
        """
        if self.spec_transform is not None:
            x = self.spec_transform(x)

        x = self.backbone(x)
        if mask is not None:
            x = x * mask

        vq_result = None
        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z

            if mask is not None:
                x = x * mask

        x = self.head(x, template=template)

        if x.ndim == 2:
            x = x[:, None, :]

        # The original tested a never-assigned `self.vq` attribute here.
        if self.quantizer is not None:
            return x, vq_result

        return x

    def encode(self, audios, audio_lengths):
        """Turn audio into quantizer indices.

        Returns:
            ``(indices, feature_lengths)`` where ``feature_lengths`` is
            ``audio_lengths // hop_length // downsample_factor``.
        """
        audios = audios.float()

        mels = self.spec_transform(audios)
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        # Zero the frames beyond each item's valid length.
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // self.downsample_factor

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths):
        """Turn quantizer indices back into audio.

        Returns:
            ``(audio, audio_lengths)``; audio beyond each item's length is
            zeroed by the mask.
        """
        frames_per_code = self.downsample_factor
        samples_per_code = frames_per_code * self.spec_transform.hop_length

        mel_masks = sequence_mask(
            feature_lengths * frames_per_code,
            indices.shape[2] * frames_per_code,
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_lengths = feature_lengths * samples_per_code

        audio_masks = sequence_mask(
            audio_lengths,
            indices.shape[2] * samples_per_code,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv

        return x, audio_lengths

    def remove_parametrizations(self):
        """Call ``remove_parametrizations`` on backbone and head if they have it."""
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()

        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        """Device of the first parameter of the model."""
        return next(self.parameters()).device
 
1
+ import math
2
+ from functools import partial
3
+ from math import prod
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from torch.nn.utils.parametrizations import weight_norm
10
+ from torch.nn.utils.parametrize import remove_parametrizations
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+
14
def sequence_mask(length, max_length=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        length: Tensor of per-sequence lengths (expected to be 1-D).
        max_length: Width of the mask; defaults to ``length.max()``.

    Returns:
        Boolean tensor whose entry ``[b, t]`` is True when ``t < length[b]``.
    """
    width = length.max() if max_length is None else max_length
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    # Broadcast (1, T) against (B, 1): compare every position with every length.
    return positions.unsqueeze(0) < length.unsqueeze(1)
19
+
20
+
21
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from ``N(mean, std)`` for matching classes only.

    Only objects whose class name contains the substring ``"Conv1D"`` are
    touched; everything else is left unchanged.

    NOTE(review): the match is case-sensitive and PyTorch's layer is named
    ``Conv1d`` (lower-case ``d``), so this may never fire for stock torch
    convolutions -- confirm whether that is intended.
    """
    if "Conv1D" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
25
+
26
+
27
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) // 2``.

    For an odd kernel this is the per-side padding that keeps the length of
    a stride-1 dilated convolution unchanged.
    """
    dilated_span = kernel_size * dilation - dilation
    return dilated_span // 2
29
+
30
+
31
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Strip ``(left, right)`` padding from the last axis of ``x``. Only for 1d!

    Raises:
        AssertionError: If a padding is negative or they exceed the length.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    # An explicit stop index keeps right == 0 correct, where
    # ``x[..., left:-right]`` would come back empty.
    return x[..., left : total - right]
38
+
39
+
40
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return the extra right padding that makes the last conv frame complete.

    With ``padding_total`` already planned, the number of frames may be
    fractional; this rounds it up and returns how many more samples the last
    axis of ``x`` needs to reach that length.
    """
    current = x.shape[-1]
    frames = (current - kernel_size + padding_total) / stride + 1
    whole_frames = math.ceil(frames)
    needed = (whole_frames - 1) * stride + (kernel_size - padding_total)
    return needed - current
48
+
49
+
50
def pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Pad the last axis of ``x``; a thin wrapper around ``F.pad``.

    Args:
        x: Tensor to pad.
        paddings: ``(left, right)`` padding sizes, both non-negative.
        mode: An ``F.pad`` mode, or ``"zeros"`` (the default), which is
            mapped to constant padding with 0. ``F.pad`` has no ``"zeros"``
            mode, so the default previously raised.
        value: Fill value for ``"constant"`` mode.

    For ``"reflect"`` on an input not longer than the padding, extra zeros
    are first appended on the right so the reflection is valid, then removed
    again from the result.

    Raises:
        AssertionError: If either padding is negative.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)

    if mode == "zeros":
        mode, value = "constant", 0.0

    if mode != "reflect":
        return F.pad(x, paddings, mode, value)

    max_pad = max(padding_left, padding_right)
    extra_pad = 0
    if length <= max_pad:
        # Reflection needs the input to be longer than the padding.
        extra_pad = max_pad - length + 1
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    end = padded.shape[-1] - extra_pad
    return padded[..., :end]
74
+
75
+
76
class FishConvNet(nn.Module):
    """``nn.Conv1d`` that pads its own input instead of using conv padding.

    ``forward`` puts ``kernel_size - stride`` zeros on the left of the last
    axis (using the dilated kernel size) and, on the right, the amount
    reported by ``get_extra_padding_for_conv1d``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.dilation = dilation
        # Effective receptive field of the dilated kernel.
        self.kernel_size = (kernel_size - 1) * dilation + 1

    def forward(self, x):
        left = self.kernel_size - self.stride
        right = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, left)
        padded = pad1d(x, (left, right), mode="constant", value=0)
        return self.conv(padded).contiguous()

    def weight_norm(self, name="weight", dim=0):
        """Wrap the inner conv with weight normalisation; returns ``self``."""
        # Inside a method body, ``weight_norm`` resolves to the module-level import.
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_parametrizations(self, name="weight"):
        """Remove the parametrization on ``name``; returns ``self``."""
        self.conv = remove_parametrizations(self.conv, name)
        return self
108
+
109
+
110
class FishTransConvNet(nn.Module):
    """``nn.ConvTranspose1d`` followed by trimming the surplus output samples.

    ``forward`` removes ``kernel_size - stride`` samples from the right end
    of the transposed-conv output and none from the left.

    NOTE(review): the trim uses the raw ``kernel_size`` rather than a
    dilation-adjusted one as ``FishConvNet`` does -- confirm behaviour for
    ``dilation != 1``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x):
        upsampled = self.conv(x)
        surplus = self.kernel_size - self.stride
        trim_right = math.ceil(surplus)
        trim_left = surplus - trim_right  # 0 for integer sizes
        return unpad1d(upsampled, (trim_left, trim_right)).contiguous()

    def weight_norm(self, name="weight", dim=0):
        """Wrap the inner conv with weight normalisation; returns ``self``."""
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_parametrizations(self, name="weight"):
        """Remove the parametrization on ``name``; returns ``self``."""
        self.conv = remove_parametrizations(self.conv, name)
        return self
134
+
135
+
136
class ResBlock1(torch.nn.Module):
    """Three residual units, each ``SiLU -> conv -> SiLU -> conv`` plus skip.

    Both convs of unit ``i`` use ``dilation[i]``; every conv is a
    weight-normalised ``FishConvNet`` with stride 1.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def build_stack():
            # Exactly three layers, indexed so a short tuple still raises.
            layers = [
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=dilation[idx]
                ).weight_norm()
                for idx in range(3)
            ]
            stack = nn.ModuleList(layers)
            stack.apply(init_weights)
            return stack

        self.convs1 = build_stack()
        self.convs2 = build_stack()

    def forward(self, x):
        for first, second in zip(self.convs1, self.convs2):
            hidden = first(F.silu(x))
            hidden = second(F.silu(hidden))
            x = hidden + x
        return x

    def remove_parametrizations(self):
        """Strip weight normalisation from every conv in both stacks."""
        for conv in (*self.convs1, *self.convs2):
            conv.remove_parametrizations()
184
+
185
+
186
class ParallelBlock(nn.Module):
    """Runs several ``ResBlock1`` s on the same input and averages the outputs.

    Args:
        channels: Channel count shared by all branches.
        kernel_sizes: Kernel size of each branch.
        dilation_sizes: Dilation triple of each branch; same length as
            ``kernel_sizes``.
    """

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        branches = [
            ResBlock1(channels, kernel, dilations)
            for kernel, dilations in zip(kernel_sizes, dilation_sizes)
        ]
        self.blocks = nn.ModuleList(branches)

    def forward(self, x):
        outputs = [branch(x) for branch in self.blocks]
        # Stack on a new leading axis and average the branches away.
        return torch.stack(outputs, dim=0).mean(dim=0)

    def remove_parametrizations(self):
        """Strip weight normalisation from every branch."""
        for branch in self.blocks:
            branch.remove_parametrizations()
207
+
208
+
209
class HiFiGANGenerator(nn.Module):
    """Upsampling generator built from the conv blocks of this module.

    Structure: a weight-normalised pre-conv, one ``FishTransConvNet`` +
    ``ParallelBlock`` pair per entry of ``upsample_rates`` (halving the
    channel count each time), then ``post_activation``, a single-channel
    post-conv and ``tanh``.

    Args (all keyword-only):
        hop_length: Must equal the product of ``upsample_rates``.
        upsample_rates: Stride of each transposed conv.
        upsample_kernel_sizes: Kernel size of each transposed conv.
        resblock_kernel_sizes: Kernel sizes handed to every ``ParallelBlock``.
        resblock_dilation_sizes: Dilations handed to every ``ParallelBlock``.
        num_mels: Input channels of the pre-conv.
        upsample_initial_channel: Output channels of the pre-conv.
        pre_conv_kernel_size: Kernel size of the pre-conv.
        post_conv_kernel_size: Kernel size of the post-conv.
        post_activation: Zero-argument factory for the final activation.
        checkpointing: When True, residual blocks run under activation
            checkpointing during training. ``forward`` always read
            ``self.checkpointing`` but nothing assigned it, so training-mode
            calls raised ``AttributeError``; the default keeps the
            non-checkpointed path.

    Raises:
        AssertionError: If ``prod(upsample_rates) != hop_length``.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
        checkpointing: bool = False,
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.checkpointing = checkpointing

        self.conv_pre = FishConvNet(
            num_mels,
            upsample_initial_channel,
            pre_conv_kernel_size,
            stride=1,
        ).weight_norm()

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        # Kept (empty) so the attribute set of the module is unchanged.
        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Stage i halves the channel count.
            self.ups.append(
                FishTransConvNet(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    stride=u,
                ).weight_norm()
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        # `ch` is the channel count of the last upsampling stage.
        self.conv_post = FishConvNet(
            ch, 1, post_conv_kernel_size, stride=1
        ).weight_norm()
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Generate the output signal from ``x``.

        Args:
            x: Input features fed to the pre-conv.
            template: Ignored. Accepted because ``FireflyArchitecture.forward``
                calls its head with ``template=...``.

        Returns:
            Tensor produced by the post-conv, squashed into ``[-1, 1]`` by tanh.
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.training and self.checkpointing:
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_parametrizations(self):
        """Strip weight normalisation from every conv owned by the generator."""
        for up in self.ups:
            up.remove_parametrizations()
        for block in self.resblocks:
            block.remove_parametrizations()
        self.conv_pre.remove_parametrizations()
        self.conv_post.remove_parametrizations()
296
+
297
+
298
+ # DropPath copied from timm library
299
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along axis 0 is kept with probability ``1 - drop_prob`` and
    zeroed otherwise. With ``scale_by_keep`` the kept samples are divided by
    the keep probability. ``x`` is returned untouched when not training or
    when ``drop_prob`` is 0.
    """  # noqa: E501
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One draw per sample, broadcastable over the remaining axes, so this
    # works for tensors of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
322
+
323
+
324
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth, applied per sample).

    Args:
        drop_prob: Probability handed to ``drop_path`` for dropping a sample.
        scale_by_keep: Forwarded to ``drop_path`` unchanged.
    """  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        # ``self.training`` makes the layer a pass-through in eval mode.
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        rounded = round(self.drop_prob, 3)
        return f"drop_prob={rounded:0.3f}"
337
+
338
+
339
class LayerNorm(nn.Module):
    r"""Layer normalisation over the channel axis for two memory layouts.

    ``channels_last`` (default) delegates to ``F.layer_norm`` over the last
    axis. ``channels_first`` normalises over axis 1 by hand and broadcasts
    the affine parameters as ``(C, 1)``, which fits ``(N, C, L)`` inputs.

    Args:
        normalized_shape: Number of channels (size of the normalised axis).
        eps: Constant added to the variance before the square root.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is neither of the two.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        # F.layer_norm expects the shape as a tuple.
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        if self.data_format != "channels_first":
            # Only reachable if data_format was changed after construction.
            return None

        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None] * normalized + self.bias[:, None]
367
+
368
+
369
+ # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
370
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block for 1-D sequences shaped ``(N, C, L)``.

    Pipeline: depthwise conv -> permute to ``(N, L, C)`` -> LayerNorm
    (channels_last) -> Linear -> GELU -> Linear -> optional layer scale ->
    permute back -> drop path -> optional residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            ``<= 0`` disables the learnable scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Accepted for interface compatibility. Default: 1.
            NOTE(review): it is not forwarded to the depthwise conv here.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        hidden_dim = int(mlp_ratio * dim)

        # Depthwise conv: one filter group per channel.
        self.dwconv = FishConvNet(dim, dim, kernel_size=kernel_size, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs are implemented as linear layers on (N, L, C).
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        shortcut = x

        y = self.dwconv(x).permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = self.drop_path(y.permute(0, 2, 1))  # (N, L, C) -> (N, C, L)

        if not apply_residual:
            return y
        return shortcut + y
436
+
437
+
438
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages over ``(N, C, L)`` inputs.

    Every stage is preceded by a projection layer: a ``FishConvNet`` stem
    (kernel 7) for the first stage and ``LayerNorm`` + 1x1 ``Conv1d`` for
    the others. All projections use stride 1.

    Args:
        input_channels: Channels of the input tensor.
        depths: Number of ``ConvNeXtBlock`` s in each stage.
        dims: Channel width of each stage; same length as ``depths``.
        drop_path_rate: Final drop-path rate; per-block rates rise linearly
            from 0 to this value across all blocks.
        layer_scale_init_value: Forwarded to every block.
        kernel_size: Depthwise kernel size forwarded to every block.
    """

    def __init__(
        self,
        input_channels: int = 3,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # Projection in front of stage 0, then one in front of each later stage.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(
            nn.Sequential(
                FishConvNet(input_channels, dims[0], kernel_size=7),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        )
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(dim_in, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(dim_in, dim_out, kernel_size=1),
                )
            )

        # One drop-path rate per block, consumed in construction order.
        rates = iter(torch.linspace(0, drop_path_rate, sum(depths)).tolist())

        self.stages = nn.ModuleList()
        for depth, dim in zip(depths, dims):
            blocks = [
                ConvNeXtBlock(
                    dim=dim,
                    drop_path=next(rates),
                    layer_scale_init_value=layer_scale_init_value,
                    kernel_size=kernel_size,
                )
                for _ in range(depth)
            ]
            self.stages.append(nn.Sequential(*blocks))

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights and zero bias for every conv / linear layer.
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for project, stage in zip(self.downsample_layers, self.stages):
            x = stage(project(x))
        return self.norm(x)
508
+
509
+
510
class FireflyArchitecture(nn.Module):
    """Wires a spectrogram transform, backbone, quantizer and head together.

    Args:
        backbone: Encoder applied to the (transformed) input.
        head: Decoder/vocoder applied to the quantized features. Called as
            ``head(x, template=...)`` in ``forward`` and ``head(z)`` in
            ``decode``.
        quantizer: Module whose call result exposes ``.z`` and which
            provides ``downsample_factor``, ``encode`` and ``decode``.
            May be ``None`` for ``forward`` only.
        spec_transform: Module producing features from audio; ``encode`` and
            ``decode`` also read its ``hop_length``. May be ``None`` for
            ``forward`` only.
    """

    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()

        self.backbone = backbone
        self.head = head
        self.quantizer = quantizer
        self.spec_transform = spec_transform
        # forward() tolerates a missing quantizer, so construction must too.
        self.downsample_factor = (
            math.prod(quantizer.downsample_factor) if quantizer is not None else 1
        )

    def forward(self, x: torch.Tensor, template=None, mask=None):
        """Run transform -> backbone -> quantizer -> head.

        Args:
            x: Input passed to ``spec_transform`` (or directly to the
                backbone when no transform is configured).
            template: Forwarded to the head as a keyword argument.
            mask: Optional tensor multiplied into the features after the
                backbone and again after quantization.

        Returns:
            ``(output, vq_result)`` when a quantizer is configured, otherwise
            just ``output``. A 2-D head output gets a channel axis inserted.
        """
        if self.spec_transform is not None:
            x = self.spec_transform(x)

        x = self.backbone(x)
        if mask is not None:
            x = x * mask

        vq_result = None
        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z

            if mask is not None:
                x = x * mask

        x = self.head(x, template=template)

        if x.ndim == 2:
            x = x[:, None, :]

        # The original tested a never-assigned `self.vq` attribute here.
        if self.quantizer is not None:
            return x, vq_result

        return x

    def encode(self, audios, audio_lengths):
        """Turn audio into quantizer indices.

        Returns:
            ``(indices, feature_lengths)`` where ``feature_lengths`` is
            ``audio_lengths // hop_length // downsample_factor``.
        """
        audios = audios.float()

        mels = self.spec_transform(audios)
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        # Zero the frames beyond each item's valid length.
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // self.downsample_factor

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths):
        """Turn quantizer indices back into audio.

        Returns:
            ``(audio, audio_lengths)``; audio beyond each item's length is
            zeroed by the mask.
        """
        frames_per_code = self.downsample_factor
        samples_per_code = frames_per_code * self.spec_transform.hop_length

        mel_masks = sequence_mask(
            feature_lengths * frames_per_code,
            indices.shape[2] * frames_per_code,
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_lengths = feature_lengths * samples_per_code

        audio_masks = sequence_mask(
            audio_lengths,
            indices.shape[2] * samples_per_code,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv

        return x, audio_lengths

    def remove_parametrizations(self):
        """Call ``remove_parametrizations`` on backbone and head if they have it."""
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()

        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        """Device of the first parameter of the model."""
        return next(self.parameters()).device
fish_speech/models/vqgan/modules/fsq.py CHANGED
@@ -1,116 +1,116 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from einops import rearrange
7
- from vector_quantize_pytorch import GroupedResidualFSQ
8
-
9
- from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
10
-
11
-
12
@dataclass
class FSQResult:
    """Bundle of tensors returned by ``DownsampleFiniteScalarQuantize.forward``."""

    # Quantized features; forward() later overwrites this with the upsampled,
    # length-matched version, so the dataclass must stay mutable.
    z: torch.Tensor
    # Code indices from the residual FSQ, last two axes swapped.
    codes: torch.Tensor
    # Downsampled features as they were fed to the quantizer.
    latents: torch.Tensor
17
-
18
-
19
- class DownsampleFiniteScalarQuantize(nn.Module):
20
- def __init__(
21
- self,
22
- input_dim: int = 512,
23
- n_codebooks: int = 9,
24
- n_groups: int = 1,
25
- levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
- downsample_factor: tuple[int] = (2, 2),
27
- downsample_dims: tuple[int] | None = None,
28
- ):
29
- super().__init__()
30
-
31
- if downsample_dims is None:
32
- downsample_dims = [input_dim for _ in range(len(downsample_factor))]
33
-
34
- all_dims = (input_dim,) + tuple(downsample_dims)
35
-
36
- self.residual_fsq = GroupedResidualFSQ(
37
- dim=all_dims[-1],
38
- levels=levels,
39
- num_quantizers=n_codebooks,
40
- groups=n_groups,
41
- )
42
-
43
- self.downsample_factor = downsample_factor
44
- self.downsample_dims = downsample_dims
45
-
46
- self.downsample = nn.Sequential(
47
- *[
48
- nn.Sequential(
49
- FishConvNet(
50
- all_dims[idx],
51
- all_dims[idx + 1],
52
- kernel_size=factor,
53
- stride=factor,
54
- ),
55
- ConvNeXtBlock(dim=all_dims[idx + 1]),
56
- )
57
- for idx, factor in enumerate(downsample_factor)
58
- ]
59
- )
60
-
61
- self.upsample = nn.Sequential(
62
- *[
63
- nn.Sequential(
64
- FishTransConvNet(
65
- all_dims[idx + 1],
66
- all_dims[idx],
67
- kernel_size=factor,
68
- stride=factor,
69
- ),
70
- ConvNeXtBlock(dim=all_dims[idx]),
71
- )
72
- for idx, factor in reversed(list(enumerate(downsample_factor)))
73
- ]
74
- )
75
-
76
- self.apply(self._init_weights)
77
-
78
- def _init_weights(self, m):
79
- if isinstance(m, (nn.Conv1d, nn.Linear)):
80
- nn.init.trunc_normal_(m.weight, std=0.02)
81
- nn.init.constant_(m.bias, 0)
82
-
83
- def forward(self, z) -> FSQResult:
84
- original_shape = z.shape
85
- z = self.downsample(z)
86
- quantized, indices = self.residual_fsq(z.mT)
87
- result = FSQResult(
88
- z=quantized.mT,
89
- codes=indices.mT,
90
- latents=z,
91
- )
92
- result.z = self.upsample(result.z)
93
-
94
- # Pad or crop z to match original shape
95
- diff = original_shape[-1] - result.z.shape[-1]
96
- left = diff // 2
97
- right = diff - left
98
-
99
- if diff > 0:
100
- result.z = F.pad(result.z, (left, right))
101
- elif diff < 0:
102
- result.z = result.z[..., left:-right]
103
-
104
- return result
105
-
106
- def encode(self, z):
107
- z = self.downsample(z)
108
- _, indices = self.residual_fsq(z.mT)
109
- indices = rearrange(indices, "g b l r -> b (g r) l")
110
- return indices
111
-
112
- def decode(self, indices: torch.Tensor):
113
- indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
114
- z_q = self.residual_fsq.get_output_from_indices(indices)
115
- z_q = self.upsample(z_q.mT)
116
- return z_q
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import GroupedResidualFSQ
8
+
9
+ from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
10
+
11
+
12
@dataclass
class FSQResult:
    """Bundle of tensors returned by ``DownsampleFiniteScalarQuantize.forward``."""

    # Quantized features; forward() later overwrites this with the upsampled,
    # length-matched version, so the dataclass must stay mutable.
    z: torch.Tensor
    # Code indices from the residual FSQ, last two axes swapped.
    codes: torch.Tensor
    # Downsampled features as they were fed to the quantizer.
    latents: torch.Tensor
17
+
18
+
19
+ class DownsampleFiniteScalarQuantize(nn.Module):
20
+ def __init__(
21
+ self,
22
+ input_dim: int = 512,
23
+ n_codebooks: int = 9,
24
+ n_groups: int = 1,
25
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
+ downsample_factor: tuple[int] = (2, 2),
27
+ downsample_dims: tuple[int] | None = None,
28
+ ):
29
+ super().__init__()
30
+
31
+ if downsample_dims is None:
32
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
33
+
34
+ all_dims = (input_dim,) + tuple(downsample_dims)
35
+
36
+ self.residual_fsq = GroupedResidualFSQ(
37
+ dim=all_dims[-1],
38
+ levels=levels,
39
+ num_quantizers=n_codebooks,
40
+ groups=n_groups,
41
+ )
42
+
43
+ self.downsample_factor = downsample_factor
44
+ self.downsample_dims = downsample_dims
45
+
46
+ self.downsample = nn.Sequential(
47
+ *[
48
+ nn.Sequential(
49
+ FishConvNet(
50
+ all_dims[idx],
51
+ all_dims[idx + 1],
52
+ kernel_size=factor,
53
+ stride=factor,
54
+ ),
55
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
56
+ )
57
+ for idx, factor in enumerate(downsample_factor)
58
+ ]
59
+ )
60
+
61
+ self.upsample = nn.Sequential(
62
+ *[
63
+ nn.Sequential(
64
+ FishTransConvNet(
65
+ all_dims[idx + 1],
66
+ all_dims[idx],
67
+ kernel_size=factor,
68
+ stride=factor,
69
+ ),
70
+ ConvNeXtBlock(dim=all_dims[idx]),
71
+ )
72
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
73
+ ]
74
+ )
75
+
76
+ self.apply(self._init_weights)
77
+
78
+ def _init_weights(self, m):
79
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
80
+ nn.init.trunc_normal_(m.weight, std=0.02)
81
+ nn.init.constant_(m.bias, 0)
82
+
83
+ def forward(self, z) -> FSQResult:
84
+ original_shape = z.shape
85
+ z = self.downsample(z)
86
+ quantized, indices = self.residual_fsq(z.mT)
87
+ result = FSQResult(
88
+ z=quantized.mT,
89
+ codes=indices.mT,
90
+ latents=z,
91
+ )
92
+ result.z = self.upsample(result.z)
93
+
94
+ # Pad or crop z to match original shape
95
+ diff = original_shape[-1] - result.z.shape[-1]
96
+ left = diff // 2
97
+ right = diff - left
98
+
99
+ if diff > 0:
100
+ result.z = F.pad(result.z, (left, right))
101
+ elif diff < 0:
102
+ result.z = result.z[..., -left:right]
103
+
104
+ return result
105
+
106
+ def encode(self, z):
107
+ z = self.downsample(z)
108
+ _, indices = self.residual_fsq(z.mT)
109
+ indices = rearrange(indices, "g b l r -> b (g r) l")
110
+ return indices
111
+
112
+ def decode(self, indices: torch.Tensor):
113
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
114
+ z_q = self.residual_fsq.get_output_from_indices(indices)
115
+ z_q = self.upsample(z_q.mT)
116
+ return z_q
fish_speech/models/vqgan/modules/reference.py DELETED
@@ -1,113 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
- from .wavenet import WaveNet
8
-
9
-
10
class ReferenceEncoder(WaveNet):
    """WaveNet feature extractor followed by latent-query attention pooling.

    A small set of learned latent vectors attends over the WaveNet output
    sequence; the attended latents are projected, passed through an MLP
    with a residual connection, mapped to ``output_channels`` and averaged
    over the latent axis.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        # ``output_channels`` is deliberately not forwarded to WaveNet: the
        # final projection is done by ``output_projection_attn`` below.
        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )

        self.head_dim = residual_channels // num_heads
        self.num_heads = num_heads

        # Learned query vectors, one row per latent slot.
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))

        # Attention projections; queries and keys are normalised per head.
        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(residual_channels, residual_channels)
        self.proj_drop = nn.Dropout(0.1)

        # Feed-forward branch applied after attention.
        self.norm = nn.LayerNorm(residual_channels)
        hidden = residual_channels * 4
        self.mlp = nn.Sequential(
            nn.Linear(residual_channels, hidden),
            nn.SiLU(),
            nn.Linear(hidden, residual_channels),
        )
        self.output_projection_attn = nn.Linear(residual_channels, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Truncated-normal weights and zero bias for every ``nn.Linear``."""
        if not isinstance(m, nn.Linear):
            return

        torch.nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, attn_mask=None):
        """Encode ``x`` and pool it into one vector per batch item.

        Args:
            x: Input passed to ``WaveNet.forward``.
            attn_mask: Optional boolean tensor of shape ``(B, N)``, where
                ``N`` is the WaveNet output length; it is broadcast over
                heads and latent slots before attention.

        Returns:
            The mean over latent slots of the projected attention output.
        """
        feats = super().forward(x).mT
        batch, frames, channels = feats.shape

        # Broadcast the per-frame mask to (B, heads, latents, N).
        if attn_mask is not None:
            assert attn_mask.shape == (batch, frames) and attn_mask.dtype == torch.bool

            attn_mask = attn_mask[:, None, None, :].expand(
                batch, self.num_heads, self.latent_len, frames
            )

        # Queries come from the learned latents, shared across the batch.
        queries = self.q(self.latent.expand(batch, -1, -1))
        queries = queries.reshape(
            batch, self.latent_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Keys and values come from the encoded sequence.
        packed = self.kv(feats).reshape(batch, frames, 2, self.num_heads, self.head_dim)
        keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        pooled = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask
        )

        # Merge heads back into the channel axis.
        pooled = pooled.transpose(1, 2).reshape(batch, self.latent_len, channels)
        pooled = self.proj_drop(self.proj(pooled))

        pooled = pooled + self.mlp(self.norm(pooled))
        return self.output_projection_attn(pooled).mean(1)
96
-
97
-
98
if __name__ == "__main__":
    # Smoke test: run one forward and backward pass on random data under
    # CPU bfloat16 autocast and print the pooled output shape.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        encoder = ReferenceEncoder(
            input_channels=128,
            output_channels=64,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )
        features = torch.randn(4, 128, 64)
        valid = torch.ones(4, 64, dtype=torch.bool)
        pooled = encoder(features, valid)
        print(pooled.shape)
        objective = F.mse_loss(pooled, torch.randn(4, 64))
        objective.backward()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/vqgan/modules/wavenet.py DELETED
@@ -1,225 +0,0 @@
1
- import math
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
-
9
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate
12
-
13
-
14
class DiffusionEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion steps.

    Each step value is multiplied by ``dim // 2`` geometrically spaced
    frequencies; the sines and cosines of those products are concatenated
    along the last axis.
    """

    def __init__(self, d_denoiser):
        super().__init__()
        self.dim = d_denoiser

    def forward(self, x):
        """Embed a 1-D tensor of steps into ``(len(x), 2 * (dim // 2))``."""
        half_dim = self.dim // 2
        # With dim 2 or 3 there is a single frequency and ``half_dim - 1``
        # is zero; clamp the divisor so those sizes work instead of
        # raising ZeroDivisionError. Larger sizes are unaffected.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
29
-
30
-
31
class LinearNorm(nn.Module):
    """Linear projection with Xavier-uniform weights and an optional zero bias."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        projection = nn.Linear(in_features, out_features, bias)
        nn.init.xavier_uniform_(projection.weight)
        if bias:
            nn.init.constant_(projection.bias, 0.0)
        self.linear = projection

    def forward(self, x):
        return self.linear(x)
45
-
46
-
47
class ConvNorm(nn.Module):
    """1-D convolution with Kaiming-normal weights.

    When ``padding`` is omitted, the kernel size must be odd and the
    padding is derived from the kernel size and dilation. ``w_init_gain``
    is accepted but not used by this implementation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super().__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            half_span = dilation * (kernel_size - 1) / 2
            padding = int(half_span)

        conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(conv.weight)
        self.conv = conv

    def forward(self, signal):
        return self.conv(signal)
82
-
83
-
84
class ResidualBlock(nn.Module):
    """Gated dilated-convolution block producing a residual and a skip output.

    The diffusion-step and condition projections are created only when
    ``condition_channels`` is given.
    """

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super().__init__()
        doubled = 2 * residual_channels

        # Padding equals the dilation so a kernel of 3 keeps the length.
        self.conv_layer = ConvNorm(
            residual_channels,
            doubled,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, doubled, kernel_size=1
            )

        self.output_projection = ConvNorm(residual_channels, doubled, kernel_size=1)

    def forward(self, x, condition=None, diffusion_step=None):
        """Return ``((x + residual) / sqrt(2), skip)``."""
        hidden = x

        if diffusion_step is not None:
            # Add the projected step embedding along the last axis.
            hidden = hidden + self.diffusion_projection(diffusion_step).unsqueeze(-1)

        hidden = self.conv_layer(hidden)

        if condition is not None:
            hidden = hidden + self.condition_projection(condition)

        # Split channels into a sigmoid gate and a tanh filter.
        gate, filt = torch.chunk(hidden, 2, dim=1)
        gated = torch.sigmoid(gate) * torch.tanh(filt)

        residual, skip = torch.chunk(self.output_projection(gated), 2, dim=1)
        return (x + residual) / math.sqrt(2.0), skip
136
-
137
-
138
class WaveNet(nn.Module):
    """Stack of gated residual blocks whose skip outputs are summed.

    Input and output projections are created only when the requested
    channel count differs from ``residual_channels``. With
    ``is_diffusion`` set, a step embedding and an MLP are added for the
    optional ``t`` argument of ``forward``.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection (only when the channel counts differ).
        needs_input_proj = (
            input_channels is not None and input_channels != residual_channels
        )
        self.input_projection = (
            ConvNorm(input_channels, residual_channels, kernel_size=1)
            if needs_input_proj
            else None
        )
        self.input_channels = (
            residual_channels if input_channels is None else input_channels
        )

        # Residual layers; dilation cycles through 1, 2, 4, ... when
        # ``dilation_cycle`` is set and stays at 1 otherwise.
        dilations = [
            2 ** (i % dilation_cycle) if dilation_cycle else 1
            for i in range(residual_layers)
        ]
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=rate,
                    condition_channels=condition_channels,
                )
                for rate in dilations
            ]
        )

        # Skip projection.
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection (only when the channel counts differ).
        needs_output_proj = (
            output_channels is not None and output_channels != residual_channels
        )
        self.output_projection = (
            ConvNorm(residual_channels, output_channels, kernel_size=1)
            if needs_output_proj
            else None
        )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        # Overrides the per-layer initialisation done by the submodules.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for conv and linear layers."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return

        nn.init.trunc_normal_(m.weight, std=0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        """Run the residual stack and return the projected skip sum.

        Args:
            x: Input features.
            t: Optional diffusion steps; embedded and passed to every block.
            condition: Optional conditioning passed to every block.
        """
        if self.input_projection is not None:
            x = F.silu(self.input_projection(x))

        if t is not None:
            t = self.mlp(self.diffusion_embedding(t))

        skips = []
        for block in self.residual_layers:
            x, skip_out = block(x, condition, t)
            skips.append(skip_out)

        # Scale the summed skips by 1/sqrt(number of layers).
        x = torch.sum(torch.stack(skips), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = self.output_projection(F.silu(x))

        return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/vqgan/spectrogram.py DELETED
@@ -1,122 +0,0 @@
1
- import torch
2
- import torchaudio.functional as F
3
- from torch import Tensor, nn
4
- from torchaudio.transforms import MelScale
5
-
6
-
7
class LinearSpectrogram(nn.Module):
    """STFT front end with reflect padding and a Hann window.

    In ``"pow2_sqrt"`` mode the output is the magnitude
    ``sqrt(re^2 + im^2 + 1e-6)``; in any other mode the real/imaginary
    pairs are returned as produced by ``torch.view_as_real``.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Non-persistent: the window is rebuilt rather than checkpointed.
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        """Compute the spectrogram of ``y``; a size-1 axis 1 of a 3-D input is dropped."""
        if y.ndim == 3:
            y = y.squeeze(1)

        # Pad both sides by half of the window/hop overlap; the right
        # side takes the extra sample when the overlap is odd.
        overlap = self.win_length - self.hop_length
        padded = torch.nn.functional.pad(
            y.unsqueeze(1),
            (overlap // 2, (overlap + 1) // 2),
            mode="reflect",
        ).squeeze(1)

        stft = torch.stft(
            padded,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(stft)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec
58
-
59
-
60
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram built on ``LinearSpectrogram``.

    The mel filter bank uses Slaney scale and normalisation. ``f_max``
    falls back to half the sample rate when it is not given.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or float(sample_rate // 2)

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

        mel_filters = F.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        # Non-persistent: the filter bank is rebuilt rather than checkpointed.
        self.register_buffer("fb", mel_filters, persistent=False)

    def compress(self, x: Tensor) -> Tensor:
        """Natural log of ``x`` after clamping it to at least 1e-5."""
        floored = torch.clamp(x, min=1e-5)
        return torch.log(floored)

    def decompress(self, x: Tensor) -> Tensor:
        """Inverse of the logarithm in ``compress`` (the clamp is not undone)."""
        return torch.exp(x)

    def apply_mel_scale(self, x: Tensor) -> Tensor:
        """Project the second-to-last axis of ``x`` through the mel filter bank."""
        projected = torch.matmul(x.transpose(-1, -2), self.fb)
        return projected.transpose(-1, -2)

    def forward(
        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
    ) -> Tensor:
        """Return the log-mel spectrogram of ``x``.

        Args:
            x: Input audio.
            return_linear: When true, return a ``(log_mel, log_linear)`` tuple.
            sample_rate: Rate of ``x``; the audio is resampled when this
                is given and differs from ``self.sample_rate``.
        """
        if sample_rate is not None and sample_rate != self.sample_rate:
            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)

        linear = self.spectrogram(x)
        log_mel = self.compress(self.apply_mel_scale(linear))

        if return_linear:
            return log_mel, self.compress(linear)

        return log_mel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/vqgan/utils.py CHANGED
@@ -1,94 +1,94 @@
1
- import matplotlib
2
- import torch
3
- from matplotlib import pyplot as plt
4
-
5
- matplotlib.use("Agg")
6
-
7
-
8
- def convert_pad_shape(pad_shape):
9
- l = pad_shape[::-1]
10
- pad_shape = [item for sublist in l for item in sublist]
11
- return pad_shape
12
-
13
-
14
- def sequence_mask(length, max_length=None):
15
- if max_length is None:
16
- max_length = length.max()
17
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18
- return x.unsqueeze(0) < length.unsqueeze(1)
19
-
20
-
21
- def init_weights(m, mean=0.0, std=0.01):
22
- classname = m.__class__.__name__
23
- if classname.find("Conv") != -1:
24
- m.weight.data.normal_(mean, std)
25
-
26
-
27
- def get_padding(kernel_size, dilation=1):
28
- return int((kernel_size * dilation - dilation) / 2)
29
-
30
-
31
- def plot_mel(data, titles=None):
32
- fig, axes = plt.subplots(len(data), 1, squeeze=False)
33
-
34
- if titles is None:
35
- titles = [None for i in range(len(data))]
36
-
37
- plt.tight_layout()
38
-
39
- for i in range(len(data)):
40
- mel = data[i]
41
-
42
- if isinstance(mel, torch.Tensor):
43
- mel = mel.float().detach().cpu().numpy()
44
-
45
- axes[i][0].imshow(mel, origin="lower")
46
- axes[i][0].set_aspect(2.5, adjustable="box")
47
- axes[i][0].set_ylim(0, mel.shape[0])
48
- axes[i][0].set_title(titles[i], fontsize="medium")
49
- axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
50
- axes[i][0].set_anchor("W")
51
-
52
- return fig
53
-
54
-
55
- def slice_segments(x, ids_str, segment_size=4):
56
- ret = torch.zeros_like(x[:, :, :segment_size])
57
- for i in range(x.size(0)):
58
- idx_str = ids_str[i]
59
- idx_end = idx_str + segment_size
60
- ret[i] = x[i, :, idx_str:idx_end]
61
-
62
- return ret
63
-
64
-
65
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
66
- b, d, t = x.size()
67
- if x_lengths is None:
68
- x_lengths = t
69
- ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
70
- ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
71
- ret = slice_segments(x, ids_str, segment_size)
72
- return ret, ids_str
73
-
74
-
75
- @torch.jit.script
76
- def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
77
- n_channels_int = n_channels[0]
78
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
79
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
80
- acts = t_act * s_act
81
-
82
- return acts
83
-
84
-
85
- def avg_with_mask(x, mask):
86
- assert mask.dtype == torch.float, "Mask should be float"
87
-
88
- if mask.ndim == 2:
89
- mask = mask.unsqueeze(1)
90
-
91
- if mask.shape[1] == 1:
92
- mask = mask.expand_as(x)
93
-
94
- return (x * mask).sum() / mask.sum()
 
1
+ import matplotlib
2
+ import torch
3
+ from matplotlib import pyplot as plt
4
+
5
+ matplotlib.use("Agg")
6
+
7
+
8
def convert_pad_shape(pad_shape):
    """Flatten per-dimension pad pairs in reverse dimension order.

    ``[[a, b], [c, d]]`` becomes ``[c, d, a, b]``.
    """
    flattened = []
    for pair in pad_shape[::-1]:
        flattened.extend(pair)
    return flattened
12
+
13
+
14
def sequence_mask(length, max_length=None):
    """Build a boolean mask that is true for positions below each length.

    Args:
        length: Tensor of sequence lengths.
        max_length: Mask width; defaults to the largest value in ``length``.

    Returns:
        A tensor with one row per entry of ``length``.
    """
    width = length.max() if max_length is None else max_length
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
19
+
20
+
21
def init_weights(m, mean=0.0, std=0.01):
    """Fill the weight with normal samples when the class name contains "Conv".

    Modules of any other class are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
25
+
26
+
27
def get_padding(kernel_size, dilation=1):
    """Return half of the dilated kernel span, truncated to an int."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
29
+
30
+
31
def plot_mel(data, titles=None):
    """Plot each item of ``data`` as an image in its own row and return the figure.

    Args:
        data: Sequence of 2-D arrays or tensors; tensors are converted to
            float NumPy arrays first.
        titles: Optional per-row titles; rows are untitled when omitted.
    """
    rows = len(data)
    fig, axes = plt.subplots(rows, 1, squeeze=False)

    if titles is None:
        titles = [None] * rows

    plt.tight_layout()

    for idx, mel in enumerate(data):
        if isinstance(mel, torch.Tensor):
            mel = mel.float().detach().cpu().numpy()

        ax = axes[idx][0]
        ax.imshow(mel, origin="lower")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[idx], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
53
+
54
+
55
def slice_segments(x, ids_str, segment_size=4):
    """Cut one fixed-width window per batch item along the last axis.

    Args:
        x: Tensor indexed as ``x[batch, channel, position]``.
        ids_str: Per-item start positions.
        segment_size: Window width.

    Returns:
        A tensor holding ``x[i, :, start_i : start_i + segment_size]`` in row ``i``.
    """
    windows = torch.zeros_like(x[:, :, :segment_size])
    for row in range(x.size(0)):
        start = ids_str[row]
        windows[row] = x[row, :, start : start + segment_size]

    return windows
63
+
64
+
65
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut one randomly positioned window per batch item.

    Args:
        x: Three-dimensional tensor; the last axis is sliced.
        x_lengths: Optional tensor of usable lengths per item. When
            omitted, every item is assumed to span the full last axis.
        segment_size: Window width.

    Returns:
        A ``(segments, start_indices)`` tuple.
    """
    b, _, t = x.size()
    if x_lengths is None:
        # ``torch.clamp`` only accepts tensors, so the full length must be
        # expanded to a tensor; a plain int here raises TypeError.
        x_lengths = torch.full((b,), t, dtype=torch.long, device=x.device)

    # Largest valid start (exclusive) per item, never below zero.
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
73
+
74
+
75
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
    """Gate the first channels (tanh) with the remaining channels (sigmoid).

    ``n_channels`` is indexed at position 0 to obtain the split point
    along axis 1 of ``in_act``.
    """
    n_channels_int = n_channels[0]
    tanh_half = torch.tanh(in_act[:, :n_channels_int, :])
    sigmoid_half = torch.sigmoid(in_act[:, n_channels_int:, :])
    return tanh_half * sigmoid_half
83
+
84
+
85
def avg_with_mask(x, mask):
    """Return the mask-weighted mean of ``x``.

    A 2-D mask gains an axis at position 1, and a mask whose axis 1 has
    size 1 is expanded to the shape of ``x`` before averaging.
    """
    assert mask.dtype == torch.float, "Mask should be float"

    weights = mask.unsqueeze(1) if mask.ndim == 2 else mask
    if weights.shape[1] == 1:
        weights = weights.expand_as(x)

    weighted_total = (x * weights).sum()
    return weighted_total / weights.sum()
fish_speech/scheduler.py CHANGED
@@ -1,40 +1,40 @@
1
- import math
2
-
3
-
4
- def get_cosine_schedule_with_warmup_lr_lambda(
5
- current_step: int,
6
- *,
7
- num_warmup_steps: int | float,
8
- num_training_steps: int,
9
- num_cycles: float = 0.5,
10
- final_lr_ratio: float = 0.0,
11
- ):
12
- if 0 < num_warmup_steps < 1: # float mode
13
- num_warmup_steps = int(num_warmup_steps * num_training_steps)
14
-
15
- if current_step < num_warmup_steps:
16
- return float(current_step) / float(max(1, num_warmup_steps))
17
-
18
- progress = float(current_step - num_warmup_steps) / float(
19
- max(1, num_training_steps - num_warmup_steps)
20
- )
21
-
22
- return max(
23
- final_lr_ratio,
24
- 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
25
- )
26
-
27
-
28
- def get_constant_schedule_with_warmup_lr_lambda(
29
- current_step: int,
30
- *,
31
- num_warmup_steps: int | float,
32
- num_training_steps: int | None = None,
33
- ):
34
- if 0 < num_warmup_steps < 1: # float mode
35
- num_warmup_steps = int(num_warmup_steps * num_training_steps)
36
-
37
- if current_step < num_warmup_steps:
38
- return float(current_step) / float(max(1, num_warmup_steps))
39
-
40
- return 1.0
 
1
+ import math
2
+
3
+
4
+ def get_cosine_schedule_with_warmup_lr_lambda(
5
+ current_step: int,
6
+ *,
7
+ num_warmup_steps: int | float,
8
+ num_training_steps: int,
9
+ num_cycles: float = 0.5,
10
+ final_lr_ratio: float = 0.0,
11
+ ):
12
+ if 0 < num_warmup_steps < 1: # float mode
13
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
14
+
15
+ if current_step < num_warmup_steps:
16
+ return float(current_step) / float(max(1, num_warmup_steps))
17
+
18
+ progress = float(current_step - num_warmup_steps) / float(
19
+ max(1, num_training_steps - num_warmup_steps)
20
+ )
21
+
22
+ return max(
23
+ final_lr_ratio,
24
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
25
+ )
26
+
27
+
28
+ def get_constant_schedule_with_warmup_lr_lambda(
29
+ current_step: int,
30
+ *,
31
+ num_warmup_steps: int | float,
32
+ num_training_steps: int | None = None,
33
+ ):
34
+ if 0 < num_warmup_steps < 1: # float mode
35
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
36
+
37
+ if current_step < num_warmup_steps:
38
+ return float(current_step) / float(max(1, num_warmup_steps))
39
+
40
+ return 1.0
fish_speech/text/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .clean import clean_text
2
- from .spliter import split_text
3
-
4
- __all__ = ["clean_text", "split_text"]
 
1
+ from .clean import clean_text
2
+ from .spliter import split_text
3
+
4
+ __all__ = ["clean_text", "split_text"]
fish_speech/text/chn_text_norm/.gitignore CHANGED
@@ -1,114 +1,114 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- *.egg-info/
24
- .installed.cfg
25
- *.egg
26
- MANIFEST
27
-
28
- # PyInstaller
29
- # Usually these files are written by a python script from a template
30
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
- *.manifest
32
- *.spec
33
-
34
- # Installer logs
35
- pip-log.txt
36
- pip-delete-this-directory.txt
37
-
38
- # Unit test / coverage reports
39
- htmlcov/
40
- .tox/
41
- .coverage
42
- .coverage.*
43
- .cache
44
- nosetests.xml
45
- coverage.xml
46
- *.cover
47
- .hypothesis/
48
- .pytest_cache/
49
-
50
- # Translations
51
- *.mo
52
- *.pot
53
-
54
- # Django stuff:
55
- *.log
56
- local_settings.py
57
- db.sqlite3
58
-
59
- # Flask stuff:
60
- instance/
61
- .webassets-cache
62
-
63
- # Scrapy stuff:
64
- .scrapy
65
-
66
- # Sphinx documentation
67
- docs/_build/
68
-
69
- # PyBuilder
70
- target/
71
-
72
- # Jupyter Notebook
73
- .ipynb_checkpoints
74
-
75
- # pyenv
76
- .python-version
77
-
78
- # celery beat schedule file
79
- celerybeat-schedule
80
-
81
- # SageMath parsed files
82
- *.sage.py
83
-
84
- # Environments
85
- .env
86
- .venv
87
- env/
88
- venv/
89
- ENV/
90
- env.bak/
91
- venv.bak/
92
-
93
- # Spyder project settings
94
- .spyderproject
95
- .spyproject
96
-
97
- # Rope project settings
98
- .ropeproject
99
-
100
- # mkdocs documentation
101
- /site
102
-
103
- # mypy
104
- .mypy_cache/
105
-
106
- # JetBrains PyCharm
107
- .idea
108
-
109
- # Customize
110
- references
111
- url.txt
112
-
113
- # Git
114
- .git
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ # JetBrains PyCharm
107
+ .idea
108
+
109
+ # Customize
110
+ references
111
+ url.txt
112
+
113
+ # Git
114
+ .git
fish_speech/text/chn_text_norm/README.md CHANGED
@@ -1,36 +1,36 @@
1
- # This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
2
-
3
- # Chn Text Norm
4
-
5
- this is a repository for chinese text normalization (no longer maintained).
6
-
7
- ## Quick Start ##
8
-
9
- ### Git Clone Repo ###
10
-
11
- git clone this repo to the root directory of your project which need to use it.
12
-
13
- cd /path/to/proj
14
- git clone https://github.com/Joee1995/chn-text-norm.git
15
-
16
- after that, your doc tree should be:
17
- ```
18
- proj # root of your project
19
- |--- chn_text_norm # this chn-text-norm tool
20
- |--- text.py
21
- |--- ...
22
- |--- text_normalize.py # your text normalization code
23
- |--- ...
24
- ```
25
-
26
- ### How to Use ? ###
27
-
28
- # text_normalize.py
29
- from chn_text_norm.text import *
30
-
31
- raw_text = 'your raw text'
32
- text = Text(raw_text=raw_text).normalize()
33
-
34
- ### How to add quantums ###
35
-
36
- 打开test.py,然后你就知道怎么做了。
 
1
+ # This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
2
+
3
+ # Chn Text Norm
4
+
5
+ This is a repository for Chinese text normalization (no longer maintained).
6
+
7
+ ## Quick Start ##
8
+
9
+ ### Git Clone Repo ###
10
+
11
+ Clone this repo into the root directory of the project that needs to use it.
12
+
13
+ cd /path/to/proj
14
+ git clone https://github.com/Joee1995/chn-text-norm.git
15
+
16
+ After that, your directory tree should be:
17
+ ```
18
+ proj # root of your project
19
+ |--- chn_text_norm # this chn-text-norm tool
20
+ |--- text.py
21
+ |--- ...
22
+ |--- text_normalize.py # your text normalization code
23
+ |--- ...
24
+ ```
25
+
26
+ ### How to Use ? ###
27
+
28
+ # text_normalize.py
29
+ from chn_text_norm.text import *
30
+
31
+ raw_text = 'your raw text'
32
+ text = Text(raw_text=raw_text).normalize()
33
+
34
+ ### How to add quantums ###
35
+
36
+ 打开test.py,然后你就知道怎么做了。
fish_speech/text/chn_text_norm/basic_class.py CHANGED
@@ -1,172 +1,172 @@
1
- # -*- coding: utf-8 -*-
2
- """基本类
3
- 中文字符类
4
- 中文数字/数位类
5
- 中文数字类
6
- 中文数位类
7
- 中文数字系统类
8
- 中文数学符号类
9
- *中文其他符号类
10
- """
11
-
12
- __author__ = "Zhiyang Zhou <[email protected]>"
13
- __data__ = "2019-05-02"
14
-
15
- from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
16
-
17
-
18
- class ChineseChar(object):
19
- """
20
- 中文字符
21
- 每个字符对应简体和繁体,
22
- e.g. 简体 = '负', 繁体 = '負'
23
- 转换时可转换为简体或繁体
24
- """
25
-
26
- def __init__(self, simplified, traditional):
27
- self.simplified = simplified
28
- self.traditional = traditional
29
- self.__repr__ = self.__str__
30
-
31
- def __str__(self):
32
- return self.simplified or self.traditional or None
33
-
34
- def __repr__(self):
35
- return self.__str__()
36
-
37
-
38
- class ChineseNumberUnit(ChineseChar):
39
- """
40
- 中文数字/数位字符
41
- 每个字符除繁简体外还有一个额外的大写字符
42
- e.g. '陆' 和 '陸'
43
- """
44
-
45
- def __init__(self, power, simplified, traditional, big_s, big_t):
46
- super(ChineseNumberUnit, self).__init__(simplified, traditional)
47
- self.power = power
48
- self.big_s = big_s
49
- self.big_t = big_t
50
-
51
- def __str__(self):
52
- return "10^{}".format(self.power)
53
-
54
- @classmethod
55
- def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
56
-
57
- if small_unit:
58
- return ChineseNumberUnit(
59
- power=index + 1,
60
- simplified=value[0],
61
- traditional=value[1],
62
- big_s=value[1],
63
- big_t=value[1],
64
- )
65
- elif numbering_type == NUMBERING_TYPES[0]:
66
- return ChineseNumberUnit(
67
- power=index + 8,
68
- simplified=value[0],
69
- traditional=value[1],
70
- big_s=value[0],
71
- big_t=value[1],
72
- )
73
- elif numbering_type == NUMBERING_TYPES[1]:
74
- return ChineseNumberUnit(
75
- power=(index + 2) * 4,
76
- simplified=value[0],
77
- traditional=value[1],
78
- big_s=value[0],
79
- big_t=value[1],
80
- )
81
- elif numbering_type == NUMBERING_TYPES[2]:
82
- return ChineseNumberUnit(
83
- power=pow(2, index + 3),
84
- simplified=value[0],
85
- traditional=value[1],
86
- big_s=value[0],
87
- big_t=value[1],
88
- )
89
- else:
90
- raise ValueError(
91
- "Counting type should be in {0} ({1} provided).".format(
92
- NUMBERING_TYPES, numbering_type
93
- )
94
- )
95
-
96
-
97
- class ChineseNumberDigit(ChineseChar):
98
- """
99
- 中文数字字符
100
- """
101
-
102
- def __init__(
103
- self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
104
- ):
105
- super(ChineseNumberDigit, self).__init__(simplified, traditional)
106
- self.value = value
107
- self.big_s = big_s
108
- self.big_t = big_t
109
- self.alt_s = alt_s
110
- self.alt_t = alt_t
111
-
112
- def __str__(self):
113
- return str(self.value)
114
-
115
- @classmethod
116
- def create(cls, i, v):
117
- return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
118
-
119
-
120
- class ChineseMath(ChineseChar):
121
- """
122
- 中文数位字符
123
- """
124
-
125
- def __init__(self, simplified, traditional, symbol, expression=None):
126
- super(ChineseMath, self).__init__(simplified, traditional)
127
- self.symbol = symbol
128
- self.expression = expression
129
- self.big_s = simplified
130
- self.big_t = traditional
131
-
132
-
133
- CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
134
-
135
-
136
- class NumberSystem(object):
137
- """
138
- 中文数字系统
139
- """
140
-
141
- pass
142
-
143
-
144
- class MathSymbol(object):
145
- """
146
- 用于中文数字系统的数学符号 (繁/简体), e.g.
147
- positive = ['正', '正']
148
- negative = ['负', '負']
149
- point = ['点', '點']
150
- """
151
-
152
- def __init__(self, positive, negative, point):
153
- self.positive = positive
154
- self.negative = negative
155
- self.point = point
156
-
157
- def __iter__(self):
158
- for v in self.__dict__.values():
159
- yield v
160
-
161
-
162
- # class OtherSymbol(object):
163
- # """
164
- # 其他符号
165
- # """
166
- #
167
- # def __init__(self, sil):
168
- # self.sil = sil
169
- #
170
- # def __iter__(self):
171
- # for v in self.__dict__.values():
172
- # yield v
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本类
3
+ 中文字符类
4
+ 中文数字/数位类
5
+ 中文数字类
6
+ 中文数位类
7
+ 中文数字系统类
8
+ 中文数学符号类
9
+ *中文其他符号类
10
+ """
11
+
12
+ __author__ = "Zhiyang Zhou <[email protected]>"
13
+ __data__ = "2019-05-02"
14
+
15
+ from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
16
+
17
+
18
+ class ChineseChar(object):
19
+ """
20
+ 中文字符
21
+ 每个字符对应简体和繁体,
22
+ e.g. 简体 = '负', 繁体 = '負'
23
+ 转换时可转换为简体或繁体
24
+ """
25
+
26
+ def __init__(self, simplified, traditional):
27
+ self.simplified = simplified
28
+ self.traditional = traditional
29
+ self.__repr__ = self.__str__
30
+
31
+ def __str__(self):
32
+ return self.simplified or self.traditional or None
33
+
34
+ def __repr__(self):
35
+ return self.__str__()
36
+
37
+
38
+ class ChineseNumberUnit(ChineseChar):
39
+ """
40
+ 中文数字/数位字符
41
+ 每个字符除繁简体外还有一个额外的大写字符
42
+ e.g. '陆' 和 '陸'
43
+ """
44
+
45
+ def __init__(self, power, simplified, traditional, big_s, big_t):
46
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
47
+ self.power = power
48
+ self.big_s = big_s
49
+ self.big_t = big_t
50
+
51
+ def __str__(self):
52
+ return "10^{}".format(self.power)
53
+
54
+ @classmethod
55
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
56
+
57
+ if small_unit:
58
+ return ChineseNumberUnit(
59
+ power=index + 1,
60
+ simplified=value[0],
61
+ traditional=value[1],
62
+ big_s=value[1],
63
+ big_t=value[1],
64
+ )
65
+ elif numbering_type == NUMBERING_TYPES[0]:
66
+ return ChineseNumberUnit(
67
+ power=index + 8,
68
+ simplified=value[0],
69
+ traditional=value[1],
70
+ big_s=value[0],
71
+ big_t=value[1],
72
+ )
73
+ elif numbering_type == NUMBERING_TYPES[1]:
74
+ return ChineseNumberUnit(
75
+ power=(index + 2) * 4,
76
+ simplified=value[0],
77
+ traditional=value[1],
78
+ big_s=value[0],
79
+ big_t=value[1],
80
+ )
81
+ elif numbering_type == NUMBERING_TYPES[2]:
82
+ return ChineseNumberUnit(
83
+ power=pow(2, index + 3),
84
+ simplified=value[0],
85
+ traditional=value[1],
86
+ big_s=value[0],
87
+ big_t=value[1],
88
+ )
89
+ else:
90
+ raise ValueError(
91
+ "Counting type should be in {0} ({1} provided).".format(
92
+ NUMBERING_TYPES, numbering_type
93
+ )
94
+ )
95
+
96
+
97
+ class ChineseNumberDigit(ChineseChar):
98
+ """
99
+ 中文数字字符
100
+ """
101
+
102
+ def __init__(
103
+ self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
104
+ ):
105
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
106
+ self.value = value
107
+ self.big_s = big_s
108
+ self.big_t = big_t
109
+ self.alt_s = alt_s
110
+ self.alt_t = alt_t
111
+
112
+ def __str__(self):
113
+ return str(self.value)
114
+
115
+ @classmethod
116
+ def create(cls, i, v):
117
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
118
+
119
+
120
+ class ChineseMath(ChineseChar):
121
+ """
122
+ 中文数位字符
123
+ """
124
+
125
+ def __init__(self, simplified, traditional, symbol, expression=None):
126
+ super(ChineseMath, self).__init__(simplified, traditional)
127
+ self.symbol = symbol
128
+ self.expression = expression
129
+ self.big_s = simplified
130
+ self.big_t = traditional
131
+
132
+
133
+ CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
134
+
135
+
136
+ class NumberSystem(object):
137
+ """
138
+ 中文数字系统
139
+ """
140
+
141
+ pass
142
+
143
+
144
+ class MathSymbol(object):
145
+ """
146
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
147
+ positive = ['正', '正']
148
+ negative = ['负', '負']
149
+ point = ['点', '點']
150
+ """
151
+
152
+ def __init__(self, positive, negative, point):
153
+ self.positive = positive
154
+ self.negative = negative
155
+ self.point = point
156
+
157
+ def __iter__(self):
158
+ for v in self.__dict__.values():
159
+ yield v
160
+
161
+
162
+ # class OtherSymbol(object):
163
+ # """
164
+ # 其他符号
165
+ # """
166
+ #
167
+ # def __init__(self, sil):
168
+ # self.sil = sil
169
+ #
170
+ # def __iter__(self):
171
+ # for v in self.__dict__.values():
172
+ # yield v