diff --git a/API_FLAGS.txt b/API_FLAGS.txt
new file mode 100644
index 0000000000000000000000000000000000000000..96a6beea027f310ac7a9e3a72e31cfddc47e47c0
--- /dev/null
+++ b/API_FLAGS.txt
@@ -0,0 +1,6 @@
+# --infer
+--api
+--listen 0.0.0.0:8080 \
+--llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+--decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+--decoder-config-name firefly_gan_vq
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b7c181f7bb80a927c2f0f539bebad6e5483b595c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 Fish Audio
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 9200fb787897a02778f4481e4f1e510a15612aa8..a091163ee2db8c0a6119eb40900333297cd374b8 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,140 @@
+
+
Fish Speech
+
+**English** | [简体中文](docs/README.zh.md) | [Portuguese](docs/README.pt-BR.md) | [日本語](docs/README.ja.md) | [한국어](docs/README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+This codebase is released under Apache License and all model weights are released under CC-BY-NC-SA-4.0 License. Please refer to [LICENSE](LICENSE) for more details.
+
---
-title: Fish Audio T
-emoji: 🐨
-colorFrom: red
-colorTo: gray
-sdk: gradio
-sdk_version: 5.9.1
-app_file: app.py
-pinned: false
-license: unknown
-short_description: fish-audio测试
----
+## Fish Agent
+We are very excited to announce that we have made our self-research agent demo open source, you can now try our agent demo online at [demo](https://fish.audio/demo/live) for instant English chat and English and Chinese chat locally by following the [docs](https://speech.fish.audio/start_agent/).
+
+You should mention that the content is released under a **CC BY-NC-SA 4.0 licence**. And the demo is an early alpha test version, the inference speed needs to be optimised, and there are a lot of bugs waiting to be fixed. If you've found a bug or want to fix it, we'd be very happy to receive an issue or a pull request.
+
+## Features
+### Fish Speech
+
+1. **Zero-shot & Few-shot TTS:** Input a 10 to 30-second vocal sample to generate high-quality TTS output. **For detailed guidelines, see [Voice Cloning Best Practices](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+
+2. **Multilingual & Cross-lingual Support:** Simply copy and paste multilingual text into the input box—no need to worry about the language. Currently supports English, Japanese, Korean, Chinese, French, German, Arabic, and Spanish.
+
+3. **No Phoneme Dependency:** The model has strong generalization capabilities and does not rely on phonemes for TTS. It can handle text in any language script.
+
+4. **Highly Accurate:** Achieves a low CER (Character Error Rate) and WER (Word Error Rate) of around 2% for 5-minute English texts.
+
+5. **Fast:** With fish-tech acceleration, the real-time factor is approximately 1:5 on an Nvidia RTX 4060 laptop and 1:15 on an Nvidia RTX 4090.
+
+6. **WebUI Inference:** Features an easy-to-use, Gradio-based web UI compatible with Chrome, Firefox, Edge, and other browsers.
+
+7. **GUI Inference:** Offers a PyQt6 graphical interface that works seamlessly with the API server. Supports Linux, Windows, and macOS. [See GUI](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **Deploy-Friendly:** Easily set up an inference server with native support for Linux, Windows and MacOS, minimizing speed loss.
+
+### Fish Agent
+1. **Completely End to End:** Automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
+
+2. **Timbre Control:** Can use reference audio to control the speech timbre.
+
+3. **Emotional:** The model can generate speech with strong emotion.
+
+## Disclaimer
+
+We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
+
+## Online Demo
+
+[Fish Audio](https://fish.audio)
+
+[Fish Agent](https://fish.audio/demo/live)
+
+## Quick Start for Local Inference
+
+[inference.ipynb](/inference.ipynb)
+
+## Videos
+
+#### V1.4 Demo Video: [Youtube](https://www.youtube.com/watch?v=Ghc8cJdQyKQ)
+
+## Documents
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+
+## Samples (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+
+## Credits
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Tech Report (V1.4)
+```bibtex
+@misc{fish-speech-v1.4,
+ title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
+ author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
+ year={2024},
+ eprint={2411.01156},
+ archivePrefix={arXiv},
+ primaryClass={cs.SD},
+ url={https://arxiv.org/abs/2411.01156},
+}
+```
+
+## Sponsor
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a469e466b87612c968187cb07a5a977e35284f2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,104 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import launch_thread_safe_queue
+from tools.schema import ServeTTSRequest
+from tools.vqgan.inference import load_model as load_decoder_model
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.5",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+ parser.add_argument("--theme", type=str, default="light")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ # Check if MPS or CUDA is available
+ if torch.backends.mps.is_available():
+ args.device = "mps"
+ logger.info("mps is available, running on mps.")
+ elif not torch.cuda.is_available():
+ logger.info("CUDA is not available, running on CPU.")
+ args.device = "cpu"
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+
+ logger.info("Loading VQ-GAN model...")
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Create the inference engine
+ inference_engine = TTSInferenceEngine(
+ llama_queue=llama_queue,
+ decoder_model=decoder_model,
+ compile=args.compile,
+ precision=args.precision,
+ )
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference_engine.inference(
+ ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.5,
+ temperature=0.7,
+ format="wav",
+ )
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ # Get the inference function with the immutable arguments
+ inference_fct = get_inference_wrapper(inference_engine)
+
+ app = build_app(inference_fct, args.theme)
+ app.launch(show_api=True, share=True)
diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cbd2ba51ca6c62d967333a86981641e17c3b67ef
--- /dev/null
+++ b/docker-compose.dev.yml
@@ -0,0 +1,18 @@
+version: '3.8'
+
+services:
+ fish-speech:
+ build:
+ context: .
+ dockerfile: dockerfile.dev
+ container_name: fish-speech
+ volumes:
+ - ./:/exp
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [gpu]
+ command: tail -f /dev/null
diff --git a/dockerfile b/dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..9c716fe5388662894e730d8106e560aac8dbaa23
--- /dev/null
+++ b/dockerfile
@@ -0,0 +1,50 @@
+FROM python:3.12-slim-bookworm AS stage-1
+ARG TARGETARCH
+
+ARG HUGGINGFACE_MODEL=fish-speech-1.5
+ARG HF_ENDPOINT=https://huggingface.co
+
+WORKDIR /opt/fish-speech
+
+RUN set -ex \
+ && pip install huggingface_hub \
+ && HF_ENDPOINT=${HF_ENDPOINT} huggingface-cli download --resume-download fishaudio/${HUGGINGFACE_MODEL} --local-dir checkpoints/${HUGGINGFACE_MODEL}
+
+FROM python:3.12-slim-bookworm
+ARG TARGETARCH
+
+ARG DEPENDENCIES=" \
+ ca-certificates \
+ libsox-dev \
+ build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0 \
+ ffmpeg"
+
+RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ set -ex \
+ && rm -f /etc/apt/apt.conf.d/docker-clean \
+ && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
+ && apt-get update \
+ && apt-get -y install --no-install-recommends ${DEPENDENCIES} \
+ && echo "no" | dpkg-reconfigure dash
+
+WORKDIR /opt/fish-speech
+
+COPY . .
+
+RUN --mount=type=cache,target=/root/.cache,sharing=locked \
+ set -ex \
+ && pip install -e .[stable]
+
+COPY --from=stage-1 /opt/fish-speech/checkpoints /opt/fish-speech/checkpoints
+
+ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+EXPOSE 7860
+
+CMD ["./entrypoint.sh"]
diff --git a/dockerfile.dev b/dockerfile.dev
new file mode 100644
index 0000000000000000000000000000000000000000..ac5d18f6a6053ba758dcbc557a4b8d5d6eacf09b
--- /dev/null
+++ b/dockerfile.dev
@@ -0,0 +1,37 @@
+ARG VERSION=dev
+ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION}
+
+FROM ${BASE_IMAGE}
+
+ARG TOOLS=" \
+ git \
+ curl \
+ build-essential \
+ ffmpeg \
+ libsm6 \
+ libxext6 \
+ libjpeg-dev \
+ zlib1g-dev \
+ aria2 \
+ zsh \
+ openssh-server \
+ sudo \
+ protobuf-compiler \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0 \
+ cmake"
+
+RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ set -ex \
+ && apt-get update \
+ && apt-get -y install --no-install-recommends ${TOOLS}
+
+# Install oh-my-zsh so your terminal looks nice
+RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended
+
+# Set zsh as default shell
+RUN chsh -s /usr/bin/zsh
+ENV SHELL=/usr/bin/zsh
diff --git a/docs/CNAME b/docs/CNAME
new file mode 100644
index 0000000000000000000000000000000000000000..d506fb8b394fa80f3d329ab8450dfc102e839bd1
--- /dev/null
+++ b/docs/CNAME
@@ -0,0 +1 @@
+speech.fish.audio
diff --git a/docs/README.ja.md b/docs/README.ja.md
new file mode 100644
index 0000000000000000000000000000000000000000..e0872e988edec972a874d7c52327d3fa9380c578
--- /dev/null
+++ b/docs/README.ja.md
@@ -0,0 +1,106 @@
+
+
Fish Speech
+
+[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+このコードベースとすべてのモデルは、CC-BY-NC-SA-4.0 ライセンスの下でリリースされています。詳細については、[LICENSE](LICENSE)を参照してください。
+
+---
+
+## 機能
+
+1. **ゼロショット & フューショット TTS**:10〜30 秒の音声サンプルを入力して、高品質の TTS 出力を生成します。**詳細は [音声クローンのベストプラクティス](https://docs.fish.audio/text-to-speech/voice-clone-best-practices) を参照してください。**
+2. **多言語 & クロスリンガル対応**:多言語テキストを入力ボックスにコピーペーストするだけで、言語を気にする必要はありません。現在、英語、日本語、韓国語、中国語、フランス語、ドイツ語、アラビア語、スペイン語に対応しています。
+3. **音素依存なし**:このモデルは強力な汎化能力を持ち、TTS に音素を必要としません。あらゆる言語スクリプトに対応可能です。
+4. **高精度**:5 分間の英語テキストに対し、CER(文字誤り率)と WER(単語誤り率)は約 2%の精度を達成します。
+5. **高速**:fish-tech アクセラレーションにより、Nvidia RTX 4060 ラップトップではリアルタイムファクターが約 1:5、Nvidia RTX 4090 では約 1:15 です。
+6. **WebUI 推論**:使いやすい Gradio ベースの Web ユーザーインターフェースを搭載し、Chrome、Firefox、Edge などのブラウザに対応しています。
+7. **GUI 推論**:PyQt6 のグラフィカルインターフェースを提供し、API サーバーとシームレスに連携します。Linux、Windows、macOS に対応しています。[GUI を見る](https://github.com/AnyaCoder/fish-speech-gui)。
+8. **デプロイしやすい**:Linux、Windows、macOS にネイティブ対応した推論サーバーを簡単にセットアップでき、速度の低下を最小限に抑えます。
+
+## 免責事項
+
+コードベースの違法な使用については一切責任を負いません。DMCA(デジタルミレニアム著作権法)およびその他の関連法については、地域の法律を参照してください。
+
+## オンラインデモ
+
+[Fish Audio](https://fish.audio)
+
+## ローカル推論のクイックスタート
+
+[inference.ipynb](/inference.ipynb)
+
+## ビデオ
+
+#### V1.4 デモビデオ: https://www.bilibili.com/video/BV1pu46eVEk7
+
+#### V1.2 デモビデオ: https://www.bilibili.com/video/BV1wz421B71D
+
+#### V1.1 デモビデオ: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## ドキュメント
+
+- [英語](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/)
+
+## サンプル (2024/10/02 V1.4)
+
+- [英語](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/samples/)
+
+## クレジット
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## スポンサー
+
+
+
diff --git a/docs/README.ko.md b/docs/README.ko.md
new file mode 100644
index 0000000000000000000000000000000000000000..952ae5f2f1946b33d1f57ea42c3b4a645e039840
--- /dev/null
+++ b/docs/README.ko.md
@@ -0,0 +1,111 @@
+
+
Fish Speech
+
+[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어**
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+이 코드베이스와 모든 모델은 CC-BY-NC-SA-4.0 라이선스에 따라 배포됩니다. 자세한 내용은 [LICENSE](LICENSE)를 참조하시길 바랍니다.
+
+---
+
+## 기능
+
+1. **Zero-shot & Few-shot TTS:** 10초에서 30초의 음성 샘플을 입력하여 고품질의 TTS 출력을 생성합니다. **자세한 가이드는 [모범 사례](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)를 참조하시길 바랍니다.**
+
+2. **다국어 및 교차 언어 지원:** 다국어 걱정 없이, 텍스트를 입력창에 복사하여 붙여넣기만 하면 됩니다. 현재 영어, 일본어, 한국어, 중국어, 프랑스어, 독일어, 아랍어, 스페인어를 지원합니다.
+
+3. **음소 의존성 제거:** 이 모델은 강력한 일반화 능력을 가지고 있으며, TTS가 음소에 의존하지 않습니다. 모든 언어 스크립트 텍스트를 손쉽게 처리할 수 있습니다.
+
+4. **높은 정확도:** 영어 텍스트 기준 5분 기준에서 단, 2%의 문자 오류율(CER)과 단어 오류율(WER)을 달성합니다.
+
+5. **빠른 속도:** fish-tech 가속을 통해 실시간 인자(RTF)는 Nvidia RTX 4060 노트북에서는 약 1:5, Nvidia RTX 4090에서는 1:15입니다.
+
+6. **웹 UI 추론:** Chrome, Firefox, Edge 등 다양한 브라우저에서 호환되는 Gradio 기반의 사용하기 쉬운 웹 UI를 제공합니다.
+
+7. **GUI 추론:** PyQt6 그래픽 인터페이스를 제공하여 API 서버와 원활하게 작동합니다. Linux, Windows 및 macOS를 지원합니다. [GUI 참조](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **배포 친화적:** Linux, Windows, macOS에서 네이티브로 지원되는 추론 서버를 쉽게 설정할 수 있어 속도 손실을 최소화합니다.
+
+## 면책 조항
+
+이 코드베이스의 불법적 사용에 대해 어떠한 책임도 지지 않습니다. DMCA 및 관련 법률에 대한 로컬 법률을 참조하십시오.
+
+## 온라인 데모
+
+[Fish Audio](https://fish.audio)
+
+## 로컬 추론을 위한 빠른 시작
+
+[inference.ipynb](/inference.ipynb)
+
+## 영상
+
+#### V1.4 데모 영상: [Youtube](https://www.youtube.com/watch?v=Ghc8cJdQyKQ)
+
+## 문서
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+- [한국어](https://speech.fish.audio/ko/)
+
+## Samples (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+- [한국어](https://speech.fish.audio/ko/samples/)
+
+## Credits
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Sponsor
+
+
+
diff --git a/docs/README.pt-BR.md b/docs/README.pt-BR.md
new file mode 100644
index 0000000000000000000000000000000000000000..443617ce3025cc0a93857f4c064c987c478b3092
--- /dev/null
+++ b/docs/README.pt-BR.md
@@ -0,0 +1,114 @@
+
+
Fish Speech
+
+[English](../README.md) | [简体中文](README.zh.md) | **Portuguese** | [日本語](README.ja.md) | [한국어](README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+Este código-fonte e os modelos são publicados sob a licença CC-BY-NC-SA-4.0. Consulte [LICENSE](LICENSE) para mais detalhes.
+
+---
+
+## Funcionalidades
+
+1. **TTS Zero-shot & Few-shot**: Insira uma amostra vocal de 10 a 30 segundos para gerar saída de TTS de alta qualidade. **Para diretrizes detalhadas, veja [Melhores Práticas para Clonagem de Voz](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+
+2. **Suporte Multilíngue e Interlingual**: Basta copiar e colar o texto multilíngue na caixa de entrada—não se preocupe com o idioma. Atualmente suporta inglês, japonês, coreano, chinês, francês, alemão, árabe e espanhol.
+
+3. **Sem Dependência de Fonemas**: O modelo tem forte capacidade de generalização e não depende de fonemas para TTS. Ele pode lidar com textos em qualquer script de idioma.
+
+4. **Alta Precisão**: Alcança uma CER (Taxa de Erro de Caracteres) e WER (Taxa de Erro de Palavras) de cerca de 2% para textos de 5 minutos em inglês.
+
+5. **Rápido**: Com a aceleração fish-tech, o fator de tempo real é de aproximadamente 1:5 em um laptop Nvidia RTX 4060 e 1:15 em uma Nvidia RTX 4090.
+
+6. **Inferência WebUI**: Apresenta uma interface de usuário web baseada em Gradio, fácil de usar e compatível com navegadores como Chrome, Firefox e Edge.
+
+7. **Inferência GUI**: Oferece uma interface gráfica PyQt6 que funciona perfeitamente com o servidor API. Suporta Linux, Windows e macOS. [Veja o GUI](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **Fácil de Implantar**: Configura facilmente um servidor de inferência com suporte nativo para Linux, Windows e macOS, minimizando a perda de velocidade.
+
+## Isenção de Responsabilidade
+
+Não nos responsabilizamos por qualquer uso ilegal do código-fonte. Consulte as leis locais sobre DMCA (Digital Millennium Copyright Act) e outras leis relevantes em sua região.
+
+## Demonstração Online
+
+[Fish Audio](https://fish.audio)
+
+## Início Rápido de Inferência Local
+
+[inference.ipynb](/inference.ipynb)
+
+## Vídeos
+
+#### 1.4 Introdução: https://www.bilibili.com/video/BV1pu46eVEk7
+
+#### 1.2 Introdução: https://www.bilibili.com/video/BV1wz421B71D
+
+#### 1.1 Apresentação Técnica: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## Documentação
+
+- [Inglês](https://speech.fish.audio/)
+- [Chinês](https://speech.fish.audio/zh/)
+- [Japonês](https://speech.fish.audio/ja/)
+- [Português (Brasil)](https://speech.fish.audio/pt/)
+
+## Exemplos
+
+- [Inglês](https://speech.fish.audio/samples/)
+- [Chinês](https://speech.fish.audio/zh/samples/)
+- [Japonês](https://speech.fish.audio/ja/samples/)
+- [Português (Brasil)](https://speech.fish.audio/pt/samples/)
+
+## Agradecimentos
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Patrocinadores
+
+
+
diff --git a/docs/README.zh.md b/docs/README.zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee09a64262cd95df2c8b985ae5d5582dd84e9cac
--- /dev/null
+++ b/docs/README.zh.md
@@ -0,0 +1,109 @@
+
+
Fish Speech
+
+[English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+此代码库及模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](LICENSE) 了解更多细节.
+
+---
+
+## 特性
+
+1. **零样本 & 小样本 TTS**:输入 10 到 30 秒的声音样本即可生成高质量的 TTS 输出。**详见 [语音克隆最佳实践指南](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。**
+2. **多语言 & 跨语言支持**:只需复制并粘贴多语言文本到输入框中,无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。
+3. **无音素依赖**:模型具备强大的泛化能力,不依赖音素进行 TTS,能够处理任何文字表示的语言。
+4. **高准确率**:在 5 分钟的英文文本上,达到了约 2% 的 CER(字符错误率)和 WER(词错误率)。
+5. **快速**:通过 fish-tech 加速,在 Nvidia RTX 4060 笔记本上的实时因子约为 1:5,在 Nvidia RTX 4090 上约为 1:15。
+6. **WebUI 推理**:提供易于使用的基于 Gradio 的网页用户界面,兼容 Chrome、Firefox、Edge 等浏览器。
+7. **GUI 推理**:提供 PyQt6 图形界面,与 API 服务器无缝协作。支持 Linux、Windows 和 macOS。[查看 GUI](https://github.com/AnyaCoder/fish-speech-gui)。
+8. **易于部署**:轻松设置推理服务器,原生支持 Linux、Windows 和 macOS,最大程度减少速度损失。
+
+## 免责声明
+
+我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+
+## 在线 DEMO
+
+[Fish Audio](https://fish.audio)
+
+## 快速开始本地推理
+
+[inference.ipynb](/inference.ipynb)
+
+## 视频
+
+#### 1.4 介绍: https://www.bilibili.com/video/BV1pu46eVEk7
+
+#### 1.2 介绍: https://www.bilibili.com/video/BV1wz421B71D
+
+#### 1.1 介绍: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## 文档
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+
+## 例子 (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+
+## 鸣谢
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## 赞助
+
+
+
diff --git a/docs/assets/figs/VS_1.jpg b/docs/assets/figs/VS_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..41a3f69992edcbbaa85a21695bdc33ff81dc10d6
Binary files /dev/null and b/docs/assets/figs/VS_1.jpg differ
diff --git a/docs/assets/figs/VS_1_pt-BR.png b/docs/assets/figs/VS_1_pt-BR.png
new file mode 100644
index 0000000000000000000000000000000000000000..d7cf5c85cb1cf98d9c716d03575eb0c74d53d572
Binary files /dev/null and b/docs/assets/figs/VS_1_pt-BR.png differ
diff --git a/docs/assets/figs/agent_gradio.png b/docs/assets/figs/agent_gradio.png
new file mode 100644
index 0000000000000000000000000000000000000000..02041bf6caa02f8c598b16bd8b495ef030dc3134
Binary files /dev/null and b/docs/assets/figs/agent_gradio.png differ
diff --git a/docs/assets/figs/diagram.png b/docs/assets/figs/diagram.png
new file mode 100644
index 0000000000000000000000000000000000000000..761b012f0a38ca6effc99eeb3bacfbfe11ffece0
Binary files /dev/null and b/docs/assets/figs/diagram.png differ
diff --git a/docs/assets/figs/diagrama.png b/docs/assets/figs/diagrama.png
new file mode 100644
index 0000000000000000000000000000000000000000..140f926ad9dc3e3e494872f1ca7b7e3f24994c3b
Binary files /dev/null and b/docs/assets/figs/diagrama.png differ
diff --git a/docs/assets/figs/logo-circle.png b/docs/assets/figs/logo-circle.png
new file mode 100644
index 0000000000000000000000000000000000000000..acfa4e3703e74909e4793020c5f3494f03dcb212
Binary files /dev/null and b/docs/assets/figs/logo-circle.png differ
diff --git a/docs/en/finetune.md b/docs/en/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..bf04086bfe0af570b24f32401f22294a5cd92a0a
--- /dev/null
+++ b/docs/en/finetune.md
@@ -0,0 +1,128 @@
+# Fine-tuning
+
+Obviously, when you opened this page, you were not satisfied with the performance of the few-shot pre-trained model. You want to fine-tune a model to improve its performance on your dataset.
+
+In current version, you only need to finetune the 'LLAMA' part.
+
+## Fine-tuning LLAMA
+### 1. Prepare the dataset
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extensions `.lab`.
+
+!!! info "Dataset Format"
+ The `.lab` annotation file only needs to contain the transcription of the audio, with no special formatting required. For example, if `hi.mp3` says "Hello, goodbye," then the `hi.lab` file would contain a single line of text: "Hello, goodbye."
+
+!!! warning
+ It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. Batch extraction of semantic tokens
+
+Make sure you have downloaded the VQGAN weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+You can then run the following command to extract semantic tokens:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit.
+ For the VITS format, you can specify a file list using `--filelist xxx.list`.
+
+This command will create `.npy` files in the `data` directory, as shown below:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. Pack the dataset into protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+After the command finishes executing, you should see the `quantized-dataset-ft.protos` file in the `data` directory.
+
+### 4. Finally, fine-tuning with LoRA
+
+Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+Finally, you can start the fine-tuning by running the following command:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+ For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues.
+
+After training is complete, you can refer to the [inference](inference.md) section to generate speech.
+
+!!! info
+ By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
+ If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
+
+After training, you need to convert the LoRA weights to regular weights before performing inference.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.5 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.5-yth-lora/
+```
+!!! note
+ You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data.
diff --git a/docs/en/index.md b/docs/en/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e5308f5e7795cc1eff6ac95e7fec86616969a88
--- /dev/null
+++ b/docs/en/index.md
@@ -0,0 +1,215 @@
+# Introduction
+
+
+
+!!! warning
+ We assume no responsibility for any illegal use of the codebase. Please refer to the local laws regarding DMCA (Digital Millennium Copyright Act) and other relevant laws in your area.
+ This codebase and all models are released under the CC-BY-NC-SA-4.0 license.
+
+
+
+
+
+## Requirements
+
+- GPU Memory: 4GB (for inference), 8GB (for fine-tuning)
+- System: Linux, Windows
+
+## Windows Setup
+
+Professional Windows users may consider using WSL2 or Docker to run the codebase.
+
+```bash
+# Create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Install pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# Install fish-speech
+pip3 install -e .
+
+# (Enable acceleration) Install triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Non-professional Windows users can consider the following basic methods to run the project without a Linux environment (with model compilation capabilities, i.e., `torch.compile`):
+
+1. Extract the project package.
+2. Click `install_env.bat` to install the environment.
+3. If you want to enable compilation acceleration, follow this step:
+ 1. Download the LLVM compiler from the following links:
+ - [LLVM-17.0.6 (Official Site Download)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6 (Mirror Site Download)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - After downloading `LLVM-17.0.6-win64.exe`, double-click to install, select an appropriate installation location, and most importantly, check the `Add Path to Current User` option to add the environment variable.
+ - Confirm that the installation is complete.
+ 2. Download and install the Microsoft Visual C++ Redistributable to solve potential .dll missing issues:
+ - [MSVC++ 14.40.33810.0 Download](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Download and install Visual Studio Community Edition to get MSVC++ build tools and resolve LLVM's header file dependencies:
+ - [Visual Studio Download](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - After installing Visual Studio Installer, download Visual Studio Community 2022.
+ - As shown below, click the `Modify` button and find the `Desktop development with C++` option to select and download.
+ 4. Download and install [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. Double-click `start.bat` to open the training inference WebUI management interface. If needed, you can modify the `API_FLAGS` as prompted below.
+
+!!! info "Optional"
+
+ Want to start the inference WebUI?
+
+ Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ Want to start the API server?
+
+ Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows:
+
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ Double-click `run_cmd.bat` to enter the conda/python command line environment of this project.
+
+## Linux Setup
+
+See [pyproject.toml](../../pyproject.toml) for details.
+```bash
+# Create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Install pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# (Ubuntu / Debian User) Install sox + ffmpeg
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debian User) Install pyaudio
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# Install fish-speech
+pip3 install -e .[stable]
+```
+
+## macos setup
+
+If you want to perform inference on MPS, please add the `--device mps` flag.
+Please refer to [this PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772) for a comparison of inference speeds.
+
+!!! warning
+ The `compile` option is not officially supported on Apple Silicon devices, so there is no guarantee that inference speed will improve.
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Docker Setup
+
+1. Install NVIDIA Container Toolkit:
+
+ To use GPU for model training and inference in Docker, you need to install NVIDIA Container Toolkit:
+
+ For Ubuntu users:
+
+ ```bash
+ # Add repository
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # Install nvidia-container-toolkit
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Restart Docker service
+ sudo systemctl restart docker
+ ```
+
+ For users of other Linux distributions, please refer to: [NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+
+2. Pull and run the fish-speech image
+
+ ```shell
+ # Pull the image
+ docker pull fishaudio/fish-speech:latest-dev
+ # Run the image
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # If you need to use a different port, please modify the -p parameter to YourPort:7860
+ ```
+
+3. Download model dependencies
+
+ Make sure you are in the terminal inside the docker container, then download the required `vqgan` and `llama` models from our huggingface repository.
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+ ```
+
+4. Configure environment variables and access WebUI
+
+ In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker.
+ Then in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service.
+
+ If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.
+
+ If it's deployed on a server, replace localhost with your server's IP.
+
+## Changelog
+
+- 2024/09/10: Updated Fish-Speech to 1.4 version, with an increase in dataset size and a change in the quantizer's n_groups from 4 to 8.
+- 2024/07/02: Updated Fish-Speech to 1.2 version, remove VITS Decoder, and greatly enhanced zero-shot ability.
+- 2024/05/10: Updated Fish-Speech to 1.1 version, implement VITS decoder to reduce WER and improve timbre similarity.
+- 2024/04/22: Finished Fish-Speech 1.0 version, significantly modified VQGAN and LLAMA models.
+- 2023/12/28: Added `lora` fine-tuning support.
+- 2023/12/27: Add `gradient checkpointing`, `causual sampling`, and `flash-attn` support.
+- 2023/12/19: Updated webui and HTTP API.
+- 2023/12/18: Updated fine-tuning documentation and related examples.
+- 2023/12/17: Updated `text2semantic` model, supporting phoneme-free mode.
+- 2023/12/13: Beta version released, includes VQGAN model and a language model based on LLAMA (phoneme support only).
+
+## Acknowledgements
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/en/inference.md b/docs/en/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..d3e055639bd9ab1ddad664f5c4f9980bc8acec6e
--- /dev/null
+++ b/docs/en/inference.md
@@ -0,0 +1,135 @@
+# Inference
+
+Inference support command line, HTTP API and web UI.
+
+!!! note
+ Overall, reasoning consists of several parts:
+
+ 1. Encode a given ~10 seconds of voice using VQGAN.
+ 2. Input the encoded semantic tokens and the corresponding text into the language model as an example.
+ 3. Given a new piece of text, let the model generate the corresponding semantic tokens.
+ 4. Input the generated semantic tokens into VITS / VQGAN to decode and generate the corresponding voice.
+
+## Command Line Inference
+
+Download the required `vqgan` and `llama` models from our Hugging Face repository.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+### 1. Generate prompt from voice:
+
+!!! note
+ If you plan to let the model randomly choose a voice timbre, you can skip this step.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+You should get a `fake.npy` file.
+
+### 2. Generate semantic tokens from text:
+
+```bash
+python tools/llama/generate.py \
+ --text "The text you want to convert" \
+ --prompt-text "Your reference text" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5" \
+ --num-samples 2 \
+ --compile
+```
+
+This command will create a `codes_N` file in the working directory, where N is an integer starting from 0.
+
+!!! note
+ You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second).
+ Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter.
+
+!!! info
+ For GPUs that do not support bf16, you may need to use the `--half` parameter.
+
+### 3. Generate vocals from semantic tokens:
+
+#### VQGAN Decoder
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API Inference
+
+We provide a HTTP API for inference. You can use the following command to start the server:
+
+```bash
+python -m tools.api_server \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> If you want to speed up inference, you can add the `--compile` parameter.
+
+After that, you can view and test the API at http://127.0.0.1:8080/.
+
+Below is an example of sending a request using `tools/api_client.py`.
+
+```bash
+python -m tools.api_client \
+ --text "Text to be input" \
+ --reference_audio "Path to reference audio" \
+ --reference_text "Text content of the reference audio" \
+ --streaming True
+```
+
+The above command indicates synthesizing the desired audio according to the reference audio information and returning it in a streaming manner.
+
+The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command.
+
+```bash
+python -m tools.api_client \
+ --text "Text to input" \
+ --reference_audio "reference audio path1" "reference audio path2" \
+ --reference_text "reference audio text1" "reference audio text2"\
+ --streaming False \
+ --output "generated" \
+ --format "mp3"
+```
+
+The above command synthesizes the desired `MP3` format audio based on the information from multiple reference audios and saves it as `generated.mp3` in the current directory.
+
+You can also use `--reference_id` (only one can be used) instead of `--reference-audio` and `--reference_text`, provided that you create a `references/` folder in the project root directory, which contains any audio and annotation text.
+The currently supported reference audio has a maximum total duration of 90 seconds.
+
+
+!!! info
+ To learn more about available parameters, you can use the command `python -m tools.api_client -h`
+
+## GUI Inference
+[Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)
+
+## WebUI Inference
+
+You can start the WebUI using the following command:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> If you want to speed up inference, you can add the `--compile` parameter.
+
+!!! note
+ You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
+
+!!! note
+ You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.
+
+Enjoy!
diff --git a/docs/en/samples.md b/docs/en/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..93f378407a4ef1b69e88baa33cc170f3571217bf
--- /dev/null
+++ b/docs/en/samples.md
@@ -0,0 +1,137 @@
+# Samples
+
+ver 1.4
+
+## Credits
+Special thanks to [Seed-TTS (2024)](https://bytedancespeech.github.io/seedtts_tech_report/) for providing the evaluation data for demonstration.
+
+All prompt audio is from the Seed-TTS effect demo page, and all generated audio is from the first generation of fish-speech version 1.4.
+
+## Zero-shot In-context Learning
+
+
+
+ Language
+ Prompt
+ Same Language Generation
+ Cross-linugal Generation
+
+
+
+
+ EN
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me.
+ Your browser does not support the audio element. 处理家庭秘密从来都不是一件容易的事。然而,有时候,隐瞒是一种保护形式,旨在保护一些人免受残酷的真相伤害。有一天,我希望你能理解我行为背后的原因。在那之前,安娜,请容忍我。
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. The combinations of different textures and flavors create a perfect harmony. The succulence of the steak, the tartness of the cranberries, the crunch of pine nuts, and creaminess of blue cheese make it a truly delectable delight. Enjoy your culinary adventure!
+ Your browser does not support the audio element. 听着你的话,我心里五味杂陈。虽然我愿意一直在你身边,承担一切不幸,但我知道只有让你自己面对,才能真正让你变得更强大。所以,你要记得,无论面对何种困难,都请你坚强,我会在心里一直支持你的。
+
+
+ ZH
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+ Your browser does not support the audio element. Suddenly, there was a burst of laughter beside me. I looked at them, stood up straight with high spirit, shook the slightly fleshy arms, and smiled lightly, saying, "The flesh on my body is to hide my bursting charm. Otherwise, wouldn't it scare you?"
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 他闭上眼睛,期望这一切都能过去。然而,当他再次睁开眼睛,眼前的景象让他不禁倒吸一口气。雾气中出现的禁闭岛,陌生又熟悉,充满未知的危险。他握紧拳头,心知他的生活即将发生翻天覆地的改变。
+ Your browser does not support the audio element. He closed his eyes, expecting that all of this could pass. However, when he opened his eyes again, the sight in front of him made him couldn't help but take a deep breath. The closed island that appeared in the fog, strange and familiar, was full of unknown dangers. He tightened his fist, knowing that his life was about to undergo earth-shaking changes.
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+ Your browser does not support the audio element. Suddenly, the atmosphere became gloomy. At first glance, all the troubles seemed to surround me. I frowned, feeling that pressure, but I know I can't give up, can't admit defeat. So, I took a deep breath, and the voice in my heart told me, "Anyway, must calm down and start again."
+
+
+
+
+
+
+## Speaker Fine-tune
+
+
+
+
+
+ Text
+ Generated
+
+
+
+
+ Speaker1
+ 好呀,哈哈哈哈哈,喜欢笑的人运气都不会差哦,希望你每天笑口常开~
+ Your browser does not support the audio element.
+
+
+ 哇!恭喜你中了大乐透,八百万可真不少呢!有什么特别的计划或想法吗?
+ Your browser does not support the audio element.
+
+
+ 哼,你这么问是想请本小姐吃饭吗?如果对象是你的话,那也不是不可以。
+ Your browser does not support the audio element.
+
+
+ Speaker2
+ 是呀,他还想换个地球仪哈哈哈,看来给你积累了一些快乐值了,你还想不想再听一个其他的笑话呀?
+ Your browser does not support the audio element.
+
+
+ 嘿嘿,你是不是也想拥有甜甜的恋爱呢?《微微一笑很倾城》是你的不二选择,男女主是校花校草类型,他们通过游戏结识,再到两人见面,全程没有一点误会,真的齁甜,想想都忍不住“姨妈笑”~
+ Your browser does not support the audio element.
+
+
+ 小傻瓜,嗯……算是个很可爱很亲切的名字,有点“独特”哦,不过我有些好奇,你为什么会给我选这个昵称呢?
+ Your browser does not support the audio element.
+
+
+
+
+
+## Content Editing
+
+
+
+ Language
+ Original Text
+ Original Audio
+ Target Text
+ Edited Audio
+
+
+
+ EN
+ They can't order me to stop dreaming. If you dream a thing more than once, it's sure to come true. Have faith in your dreams, and someday your rainbow will come shining through.
+ Your browser does not support the audio element.
+ They can't require me to stop imagining. If you envision a thing more than once, it's bound to come about . Have trust in your visions , and someday your radiance will come beaming through.
+ Your browser does not support the audio element.
+
+
+ Are you familiar with it? Slice the steak and place the strips on top, then garnish with the dried cranberries, pine nuts, and blue cheese. I wonder how people rationalise the decision?
+ Your browser does not support the audio element.
+ Are you acquainted with it? Cut the pork and place the strips on top, then garnish with the dried cherries, almonds, and feta cheese. I query how people justify the choice?
+ Your browser does not support the audio element.
+
+
+ ZH
+ 自古以来,庸君最怕党政了,可圣君他就不怕,不但不怕,反能利用。要我说,你就让明珠索额图互相争宠,只要你心里明白,左右逢源,你就能立于不败之地。
+ Your browser does not support the audio element.
+ 从古至今 ,庸君最怕朝纲了 ,可明 君他就不怕,不但不怕,反能借助 。要我说,你就让李四张三 互相争宠,只要你心里清楚 ,左右周旋 ,你就能处 于不败之境 。
+ Your browser does not support the audio element.
+
+
+ 对,这就是我,万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。
+ Your browser does not support the audio element.
+ 对,这就是我,众人尊崇 的太白金星 ,虽然有点娃娃脸 ,但也遮 不住我迷人 的魅力。
+ Your browser does not support the audio element.
+
+
+
diff --git a/docs/en/start_agent.md b/docs/en/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0e32a301f6b4e5c2c81223922e58349acc4e69b
--- /dev/null
+++ b/docs/en/start_agent.md
@@ -0,0 +1,77 @@
+# Start Agent
+
+## Requirements
+
+- GPU memory: At least 8GB(under quanization), 16GB or more is recommanded.
+- Disk usage: 10GB
+
+## Download Model
+
+You can get the model by:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+Put them in the 'checkpoints' folder.
+
+You also need the fish-speech model which you can download instructed by [inference](inference.md).
+
+So there will be 2 folder in the checkpoints.
+
+The `checkpoints/fish-speech-1.4` and `checkpoints/fish-agent-v0.1-3b`
+
+## Environment Prepare
+
+If you already have Fish-speech, you can directly use by adding the follow instruction:
+```bash
+pip install cachetools
+```
+
+!!! note
+ Please use the Python version below 3.12 for compile.
+
+If you don't have, please use the below commands to build yout environment:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## Launch The Agent Demo.
+
+To build fish-agent, please use the command below under the main folder:
+
+```bash
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+The `--compile` args only support Python < 3.12 , which will greatly speed up the token generation.
+
+It won't compile at once (remember).
+
+Then open another terminal and use the command:
+
+```bash
+python -m tools.e2e_webui
+```
+
+This will create a Gradio WebUI on the device.
+
+When you first use the model, it will come to compile (if the `--compile` is True) for a short time, so please wait with patience.
+
+## Gradio Webui
+
+
+
+
+Have a good time!
+
+## Performance
+
+Under our test, a 4060 laptop just barely runs, but is very stretched, which is only about 8 tokens/s. The 4090 is around 95 tokens/s under compile, which is what we recommend.
+
+# About Agent
+
+The demo is an early alpha test version, the inference speed needs to be optimised, and there are a lot of bugs waiting to be fixed. If you've found a bug or want to fix it, we'd be very happy to receive an issue or a pull request.
diff --git a/docs/ja/finetune.md b/docs/ja/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfc049b184b2b4f55815a5852ae782c902a04ceb
--- /dev/null
+++ b/docs/ja/finetune.md
@@ -0,0 +1,128 @@
+# 微調整
+
+明らかに、このページを開いたとき、few-shot 事前トレーニングモデルのパフォーマンスに満足していなかったことでしょう。データセット上でのパフォーマンスを向上させるためにモデルを微調整したいと考えています。
+
+現在のバージョンでは、「LLAMA」部分のみを微調整する必要があります。
+
+## LLAMAの微調整
+### 1. データセットの準備
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+データセットを上記の形式に変換し、「data」ディレクトリに配置する必要があります。音声ファイルの拡張子は「.mp3」、「.wav」、または「.flac」にすることができ、注釈ファイルの拡張子は「.lab」にする必要があります。
+
+!!! info
+ 標準ファイル `.lab` には、音声の転写テキストのみを含め、特別なフォーマットは必要ありません。例えば、`hi.mp3` で「こんにちは、さようなら」と言っている場合、`hi.lab` ファイルには「こんにちは、さようなら」という一行のテキストを含めるだけです。
+
+!!! warning
+ データセットにラウドネス正規化を適用することをお勧めします。これを行うには、[fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. セマンティックトークンのバッチ抽出
+
+VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+次に、次のコマンドを実行してセマンティックトークンを抽出できます。
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ `--num-workers` と `--batch-size` を調整して抽出速度を上げることができますが、GPUメモリの制限を超えないようにしてください。
+ VITS形式の場合、`--filelist xxx.list` を使用してファイルリストを指定できます。
+
+このコマンドは、`data`ディレクトリに`.npy`ファイルを作成します。以下のように表示されます。
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. データセットをprotobufにパックする
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+コマンドの実行が完了すると、`data`ディレクトリに`quantized-dataset-ft.protos`ファイルが表示されます。
+
+### 4. 最後に、LoRAを使用して微調整する
+
+同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+最後に、次のコマンドを実行して微調整を開始できます。
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ `fish_speech/configs/text2semantic_finetune.yaml` を変更して、`batch_size`、`gradient_accumulation_steps` などのトレーニングパラメータを変更し、GPUメモリに適合させることができます。
+
+!!! note
+ Windowsユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。
+
+トレーニングが完了したら、[推論](inference.md)セクションを参照し、音声を生成します。
+
+!!! info
+ デフォルトでは、モデルは話者の発話パターンのみを学習し、音色は学習しません。音色の安定性を確保するためにプロンプトを使用する必要があります。
+ 音色を学習したい場合は、トレーニングステップ数を増やすことができますが、これにより過学習が発生する可能性があります。
+
+トレーニングが完了したら、推論を行う前にLoRAの重みを通常の重みに変換する必要があります。
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.5 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.5-yth-lora/
+```
+!!! note
+ 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。
diff --git a/docs/ja/index.md b/docs/ja/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..a7889055956abbd3351fd3e83432b08425e4d50f
--- /dev/null
+++ b/docs/ja/index.md
@@ -0,0 +1,214 @@
+# Fish Speech の紹介
+
+
+
+!!! warning
+ 私たちは、コードベースの違法な使用について一切の責任を負いません。お住まいの地域の DMCA(デジタルミレニアム著作権法)およびその他の関連法を参照してください。
+ このコードベースとモデルは、CC-BY-NC-SA-4.0 ライセンス下でリリースされています。
+
+
+
+
+
+## 要件
+
+- GPU メモリ: 4GB(推論用)、8GB(ファインチューニング用)
+- システム: Linux、Windows
+
+## Windowsセットアップ
+
+プロフェッショナルなWindowsユーザーは、WSL2またはDockerを使用してコードベースを実行することを検討してください。
+
+```bash
+# Python 3.10の仮想環境を作成(virtualenvも使用可能)
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# PyTorchをインストール
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# fish-speechをインストール
+pip3 install -e .
+
+# (アクセラレーションを有効にする) triton-windowsをインストール
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+非プロフェッショナルなWindowsユーザーは、Linux環境なしでプロジェクトを実行するための以下の基本的な方法を検討できます(モデルコンパイル機能、つまり`torch.compile`を使用可能):
+
+1. プロジェクトパッケージを解凍する。
+2. `install_env.bat`をクリックして環境をインストールする。
+3. コンパイルアクセラレーションを有効にしたい場合は、次のステップに従ってください:
+ 1. 以下のリンクからLLVMコンパイラをダウンロード:
+ - [LLVM-17.0.6(公式サイトのダウンロード)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6(ミラーサイトのダウンロード)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - `LLVM-17.0.6-win64.exe`をダウンロードした後、ダブルクリックしてインストールし、適切なインストール場所を選択し、最も重要なのは`Add Path to Current User`オプションを選択して環境変数を追加することです。
+ - インストールが完了したことを確認する。
+ 2. 欠落している .dll の問題を解決するため、Microsoft Visual C++ Redistributable をダウンロードしてインストールする:
+ - [MSVC++ 14.40.33810.0 ダウンロード](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Visual Studio Community Editionをダウンロードして、MSVC++ビルドツールを取得し、LLVMのヘッダーファイルの依存関係を解決する:
+ - [Visual Studio ダウンロード](https://visualstudio.microsoft.com/ja/downloads/)
+ - Visual Studio Installerをインストールした後、Visual Studio Community 2022をダウンロード。
+ - 下記のように、`Modify`ボタンをクリックし、`C++によるデスクトップ開発`オプションを選択してダウンロード。
+ -
+ 4. [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)をダウンロードしてインストールする。
+4. `start.bat`をダブルクリックして、トレーニング推論WebUI管理インターフェースを開きます。必要に応じて、以下に示すように`API_FLAGS`を修正できます。
+
+
+!!! info "オプション"
+ 推論WebUIを起動しますか?
+ プロジェクトのルートディレクトリにある `API_FLAGS.txt` ファイルを編集し、最初の3行を次のように変更します:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "オプション"
+ APIサーバーを起動しますか?
+ プロジェクトのルートディレクトリにある `API_FLAGS.txt` ファイルを編集し、最初の3行を次のように変更します:
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "オプション"
+ `run_cmd.bat` をダブルクリックして、このプロジェクトの conda/python コマンドライン環境に入ります。
+
+
+
+## Linux セットアップ
+
+詳細については、[pyproject.toml](../../pyproject.toml) を参照してください。
+```bash
+# python 3.10の仮想環境を作成します。virtualenvも使用できます。
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# pytorchをインストールします。
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# (Ubuntu / Debianユーザー) sox + ffmpegをインストールします。
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debianユーザー) pyaudio をインストールします。
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# fish-speechをインストールします。
+pip3 install -e .[stable]
+
+```
+
+## macos setup
+
+推論をMPS上で行う場合は、`--device mps`フラグを追加してください。
+推論速度の比較は[こちらのPR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772)を参考にしてください。
+
+!!! warning
+ AppleSiliconのデバイスでは、compileオプションに正式に対応していませんので、推論速度が向上する保証はありません。
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Docker セットアップ
+
+1. NVIDIA Container Toolkit のインストール:
+
+ Docker で GPU を使用してモデルのトレーニングと推論を行うには、NVIDIA Container Toolkit をインストールする必要があります:
+
+ Ubuntu ユーザーの場合:
+
+ ```bash
+ # リポジトリの追加
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # nvidia-container-toolkit のインストール
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Docker サービスの再起動
+ sudo systemctl restart docker
+ ```
+
+ 他の Linux ディストリビューションを使用している場合は、以下のインストールガイドを参照してください:[NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)。
+
+2. fish-speech イメージのプルと実行
+
+ ```shell
+ # イメージのプル
+ docker pull fishaudio/fish-speech:latest-dev
+ # イメージの実行
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # 他のポートを使用する場合は、-p パラメータを YourPort:7860 に変更してください
+ ```
+
+3. モデルの依存関係のダウンロード
+
+ Docker コンテナ内のターミナルにいることを確認し、huggingface リポジトリから必要な `vqgan` と `llama` モデルをダウンロードします。
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+ ```
+
+4. 環境変数の設定と WebUI へのアクセス
+
+ Docker コンテナ内のターミナルで、`export GRADIO_SERVER_NAME="0.0.0.0"` と入力して、外部から Docker 内の gradio サービスにアクセスできるようにします。
+ 次に、Docker コンテナ内のターミナルで `python tools/run_webui.py` と入力して WebUI サービスを起動します。
+
+ WSL または MacOS の場合は、[http://localhost:7860](http://localhost:7860) にアクセスして WebUI インターフェースを開くことができます。
+
+ サーバーにデプロイしている場合は、localhost をサーバーの IP に置き換えてください。
+
+## 変更履歴
+
+- 2024/09/10: Fish-Speech を Ver.1.4 に更新し、データセットのサイズを増加させ、quantizer n_groups を 4 から 8 に変更しました。
+- 2024/07/02: Fish-Speech を Ver.1.2 に更新し、VITS デコーダーを削除し、ゼロショット能力を大幅に強化しました。
+- 2024/05/10: Fish-Speech を Ver.1.1 に更新し、VITS デコーダーを実装して WER を減少させ、音色の類似性を向上させました。
+- 2024/04/22: Fish-Speech Ver.1.0 を完成させ、VQGAN および LLAMA モデルを大幅に修正しました。
+- 2023/12/28: `lora`微調整サポートを追加しました。
+- 2023/12/27: `gradient checkpointing`、`causual sampling`、および`flash-attn`サポートを追加しました。
+- 2023/12/19: webui および HTTP API を更新しました。
+- 2023/12/18: 微調整ドキュメントおよび関連例を更新しました。
+- 2023/12/17: `text2semantic`モデルを更新し、自由音素モードをサポートしました。
+- 2023/12/13: ベータ版をリリースし、VQGAN モデルおよび LLAMA に基づく言語モデル(音素のみサポート)を含みます。
+
+## 謝辞
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/ja/inference.md b/docs/ja/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..ed558c9d40f1f60eea32ed051f6ab7a7e16abe87
--- /dev/null
+++ b/docs/ja/inference.md
@@ -0,0 +1,114 @@
+# 推論
+
+推論は、コマンドライン、HTTP API、および Web UI をサポートしています。
+
+!!! note
+ 全体として、推論は次のいくつかの部分で構成されています:
+
+ 1. VQGANを使用して、与えられた約10秒の音声をエンコードします。
+ 2. エンコードされたセマンティックトークンと対応するテキストを例として言語モデルに入力します。
+ 3. 新しいテキストが与えられた場合、モデルに対応するセマンティックトークンを生成させます。
+ 4. 生成されたセマンティックトークンをVITS / VQGANに入力してデコードし、対応する音声を生成します。
+
+## コマンドライン推論
+
+必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+### 1. 音声からプロンプトを生成する:
+
+!!! note
+ モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+`fake.npy`ファイルが生成されるはずです。
+
+### 2. テキストからセマンティックトークンを生成する:
+
+```bash
+python tools/llama/generate.py \
+ --text "変換したいテキスト" \
+ --prompt-text "参照テキスト" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5" \
+ --num-samples 2 \
+ --compile
+```
+
+このコマンドは、作業ディレクトリに`codes_N`ファイルを作成します。ここで、N は 0 から始まる整数です。
+
+!!! note
+ `--compile`を使用して CUDA カーネルを融合し、より高速な推論を実現することができます(約 30 トークン/秒 -> 約 500 トークン/秒)。
+ それに対応して、加速を使用しない場合は、`--compile`パラメータをコメントアウトできます。
+
+!!! info
+ bf16 をサポートしていない GPU の場合、`--half`パラメータを使用する必要があるかもしれません。
+
+### 3. セマンティックトークンから音声を生成する:
+
+#### VQGAN デコーダー
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API 推論
+
+推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:
+
+```bash
+python -m tools.api_server \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> 推論を高速化したい場合は、`--compile` パラメータを追加できます。
+
+その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。
+
+以下は、`tools/api_client.py` を使用してリクエストを送信する例です。
+
+```bash
+python -m tools.api_client \
+ --text "入力するテキスト" \
+ --reference_audio "参照音声へのパス" \
+ --reference_text "参照音声テキスト" \
+ --streaming True
+```
+
+上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。
+
+!!! info
+ 使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください
+
+## WebUI 推論
+
+次のコマンドを使用して WebUI を起動できます:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> 推論を高速化したい場合は、`--compile` パラメータを追加できます。
+
+!!! note
+ ラベルファイルと参照音声ファイルをメインディレクトリの `references` フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。
+
+!!! note
+ Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。
+
+お楽しみください!
diff --git a/docs/ja/samples.md b/docs/ja/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..13907736daed5961c66f56524bb0fa7068b76f3a
--- /dev/null
+++ b/docs/ja/samples.md
@@ -0,0 +1,225 @@
+# サンプル
+
+v1.4デモは[こちら](https://speech.fish.audio/samples/)に更新されています
+
+v1.2のサンプルは[Bilibili](https://www.bilibili.com/video/BV1wz421B71D/)で利用可能です。
+
+以下のサンプルはv1.1モデルからのものです。
+
+## 中国語の文1
+```
+人間灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ナヒーダ (原神)
+
+
+
+
+ 鍾離 (原神)
+
+
+
+
+ フリナ (原神)
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+
+## 中国語の文2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ナヒーダ (原神)
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+
+## 中国語の文3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+## 英語の文1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+## 英語の文2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+## 日本語の文1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+## 日本語の文2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
diff --git a/docs/ja/start_agent.md b/docs/ja/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..79b1b2c9ac6d310ed467612f15ce348d3ac16e35
--- /dev/null
+++ b/docs/ja/start_agent.md
@@ -0,0 +1,80 @@
+# エージェントの開始
+
+!!! note
+ もしあなたがネイティブ・スピーカーで、翻訳に問題があるとお感じでしたら、issueかpull requestをお送りください!
+
+## 要件
+
+- GPUメモリ: 最低8GB(量子化使用時)、16GB以上推奨
+- ディスク使用量: 10GB
+
+## モデルのダウンロード
+
+以下のコマンドでモデルを取得できます:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+これらを'checkpoints'フォルダに配置してください。
+
+また、[inference](inference.md)の手順に従ってfish-speechモデルもダウンロードする必要があります。
+
+checkpointsには2つのフォルダが必要です。
+
+`checkpoints/fish-speech-1.4`と`checkpoints/fish-agent-v0.1-3b`です。
+
+## 環境準備
+
+すでにFish-speechをお持ちの場合は、以下の指示を追加するだけで直接使用できます:
+```bash
+pip install cachetools
+```
+
+!!! note
+ コンパイルにはPythonバージョン3.12未満を使用してください。
+
+お持ちでない場合は、以下のコマンドで環境を構築してください:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## エージェントデモの起動
+
+fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください:
+
+```bash
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+`--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。
+
+一度にコンパイルは行われません(覚えておいてください)。
+
+次に、別のターミナルを開いて以下のコマンドを使用します:
+
+```bash
+python -m tools.e2e_webui
+```
+
+これにより、デバイス上にGradio WebUIが作成されます。
+
+モデルを初めて使用する際は、(`--compile`がTrueの場合)しばらくコンパイルが行われますので、お待ちください。
+
+## Gradio Webui
+
+
+
+
+お楽しみください!
+
+## パフォーマンス
+
+テストでは、4060搭載のラップトップではかろうじて動作しますが、非常に厳しい状態で、約8トークン/秒程度です。4090ではコンパイル時に約95トークン/秒で、これが推奨環境です。
+
+# エージェントについて
+
+このデモは初期アルファテストバージョンで、推論速度の最適化が必要で、修正を待つバグが多数あります。バグを発見した場合や修正したい場合は、issueやプルリクエストをいただけると大変嬉しく思います。
diff --git a/docs/ko/finetune.md b/docs/ko/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..85cf11f19ed34f593f2026f1d98bc906947aa39b
--- /dev/null
+++ b/docs/ko/finetune.md
@@ -0,0 +1,128 @@
+# 파인튜닝
+
+이 페이지를 열었다는 것은, 사전 학습된 퓨샷(Few-shot) 모델의 성능에 만족하지 못했다는 의미일 것입니다. 데이터셋의 성능을 향상시키기 위해 모델을 파인튜닝하고 싶으시겠죠.
+
+현재 버전에서는 'LLAMA' 부분만 파인튜닝하시면 됩니다.
+
+## LLAMA 파인튜닝
+### 1. 데이터셋 준비
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+위와 같은 형식으로 데이터셋을 변환하여 `data` 디렉토리 안에 배치하세요. 오디오 파일의 확장자는 `.mp3`, `.wav`, `.flac` 중 하나여야 하며, 주석 파일은 `.lab` 확장자를 사용해야 합니다.
+
+!!! info "데이터셋 형식"
+ `.lab` 주석 파일은 오디오의 전사 내용만 포함하면 되며, 특별한 형식이 필요하지 않습니다. 예를 들어, `hi.mp3`에서 "Hello, goodbye"라는 대사를 말한다면, `hi.lab` 파일에는 "Hello, goodbye"라는 한 줄의 텍스트만 있어야 합니다.
+
+!!! warning
+ 데이터셋에 대한 음량 정규화(loudness normalization)를 적용하는 것이 좋습니다. 이를 위해 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess)를 사용할 수 있습니다.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+### 2. 시맨틱 토큰 배치 추출
+
+VQGAN 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+이후 시맨틱 토큰을 추출하기 위해 아래 명령어를 실행하세요:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ 추출 속도를 높이기 위해 `--num-workers`와 `--batch-size` 값을 조정할 수 있지만, GPU 메모리 한도를 초과하지 않도록 주의하세요.
+ VITS 형식의 경우, `--filelist xxx.list`를 사용하여 파일 목록을 지정할 수 있습니다.
+
+이 명령을 실행하면 `data` 디렉토리 안에 `.npy` 파일이 생성됩니다. 다음과 같이 표시됩니다:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. 데이터셋을 protobuf로 패킹
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+명령이 완료되면 `data` 디렉토리 안에 `quantized-dataset-ft.protos` 파일이 생성됩니다.
+
+### 4. 마지막으로, LoRA를 이용한 파인튜닝
+
+마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+마지막으로, 아래 명령어를 실행하여 파인튜닝을 시작할 수 있습니다:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ `batch_size`, `gradient_accumulation_steps` 등의 학습 매개변수를 GPU 메모리에 맞게 조정하려면 `fish_speech/configs/text2semantic_finetune.yaml` 파일을 수정할 수 있습니다.
+
+!!! note
+ Windows 사용자의 경우, `nccl` 문제를 피하려면 `trainer.strategy.process_group_backend=gloo`를 사용할 수 있습니다.
+
+훈련이 완료되면 [추론](inference.md) 섹션을 참고하여 음성을 생성할 수 있습니다.
+
+!!! info
+ 기본적으로 모델은 화자의 말하는 패턴만 학습하고 음색은 학습하지 않습니다. 음색의 안정성을 위해 프롬프트를 사용해야 합니다.
+ 음색을 학습하려면 훈련 단계를 늘릴 수 있지만, 이는 과적합의 위험을 초래할 수 있습니다.
+
+훈련이 끝나면 LoRA 가중치를 일반 가중치로 변환한 후에 추론을 수행해야 합니다.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.5 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.5-yth-lora/
+```
+
+!!! note
+ 다른 체크포인트도 시도해 볼 수 있습니다. 요구 사항에 맞는 가장 초기 체크포인트를 사용하는 것이 좋습니다. 이들은 종종 분포 밖(OOD) 데이터에서 더 좋은 성능을 발휘합니다.
diff --git a/docs/ko/index.md b/docs/ko/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..f65974f8fdf805fb7ada7b467df5a301bad7ea4f
--- /dev/null
+++ b/docs/ko/index.md
@@ -0,0 +1,215 @@
+# 소개
+
+
+
+!!! warning
+ 이 코드베이스의 불법적인 사용에 대해서는 책임을 지지 않습니다. DMCA(Digital Millennium Copyright Act) 및 해당 지역의 관련 법률을 참조하십시오.
+ 이 코드베이스와 모든 모델은 CC-BY-NC-SA-4.0 라이선스에 따라 배포됩니다.
+
+
+
+
+
+## 요구 사항
+
+- GPU 메모리: 4GB (추론용), 8GB (파인튜닝용)
+- 시스템: Linux, Windows
+
+## Windows 설정
+
+고급 Windows 사용자는 WSL2 또는 Docker를 사용하여 코드베이스를 실행하는 것을 고려할 수 있습니다.
+
+```bash
+# 파이썬 3.10 가상 환경 생성, virtualenv도 사용할 수 있습니다.
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# pytorch 설치
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# fish-speech 설치
+pip3 install -e .
+
+# (가속 활성화) triton-windows 설치
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+비전문 Windows 사용자는 Linux 환경 없이 프로젝트를 실행할 수 있는 다음 기본 방법을 고려할 수 있습니다 (모델 컴파일 기능 포함, 즉 `torch.compile`):
+
+1. 프로젝트 패키지 추출.
+2. `install_env.bat`을 클릭하여 환경 설치.
+3. 컴파일 가속을 활성화하려면 아래 단계를 따르세요:
+ 1. LLVM 컴파일러 다운로드:
+ - [LLVM-17.0.6 (공식 사이트)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6 (미러 사이트)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - `LLVM-17.0.6-win64.exe`를 다운로드 후 더블클릭하여 설치하고, 설치 경로 선택 시 `Add Path to Current User` 옵션을 체크하여 환경 변수를 추가합니다.
+ - 설치가 완료되었는지 확인합니다.
+ 2. Microsoft Visual C++ 재배포 가능 패키지를 다운로드하여 .dll 누락 문제 해결:
+ - [MSVC++ 14.40.33810.0 다운로드](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Visual Studio Community Edition을 다운로드하여 LLVM의 헤더 파일 의존성을 해결:
+ - [Visual Studio 다운로드](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - Visual Studio Installer를 설치한 후 Visual Studio Community 2022를 다운로드.
+ - `Desktop development with C++` 옵션을 선택하여 설치.
+ 4. [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64) 다운로드 및 설치.
+4. `start.bat`을 더블 클릭하여 훈련 추론 WebUI 관리 인터페이스를 엽니다. 필요한 경우 아래 지침에 따라 `API_FLAGS`를 수정할 수 있습니다.
+
+!!! info "Optional"
+
+ 추론을 위해 WebUI를 사용하고자 하시나요?
+
+ 프로젝트 루트 디렉토리의 `API_FLAGS.txt` 파일을 편집하고 첫 세 줄을 아래와 같이 수정하세요:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ API 서버를 시작하고 싶으신가요?
+
+ 프로젝트 루트 디렉토리의 `API_FLAGS.txt` 파일을 편집하고 첫 세 줄을 아래와 같이 수정하세요:
+
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ `run_cmd.bat`을 더블 클릭하여 이 프로젝트의 conda/python 명령줄 환경에 진입할 수 있습니다.
+
+## Linux 설정
+
+[pyproject.toml](../../pyproject.toml)에서 자세한 내용을 확인하세요.
+```bash
+# 파이썬 3.10 가상 환경 생성, virtualenv도 사용할 수 있습니다.
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# (Ubuntu / Debian 사용자) sox + ffmpeg 설치
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debian 사용자) pyaudio 설치
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# pytorch 설치
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# fish-speech 설치
+pip3 install -e .[stable]
+```
+
+## macos 설정
+
+MPS에서 추론을 수행하려면 `--device mps` 플래그를 추가하세요.
+추론 속도 비교는 [이 PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772)을 참조하십시오.
+
+!!! warning
+ Apple Silicon 장치에서는 `compile` 옵션이 공식적으로 지원되지 않으므로 추론 속도가 향상된다는 보장은 없습니다.
+
+```bash
+# 파이썬 3.10 가상 환경 생성, virtualenv도 사용할 수 있습니다.
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# pytorch 설치
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# fish-speech 설치
+pip install -e .[stable]
+```
+
+## Docker 설정
+
+1. NVIDIA Container Toolkit 설치:
+
+ Docker에서 모델 훈련 및 추론에 GPU를 사용하려면 NVIDIA Container Toolkit을 설치해야 합니다:
+
+ Ubuntu 사용자:
+
+ ```bash
+ # 저장소 추가
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # nvidia-container-toolkit 설치
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Docker 서비스 재시작
+ sudo systemctl restart docker
+ ```
+
+ 다른 Linux 배포판 사용자는: [NVIDIA Container Toolkit 설치 가이드](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)를 참조하십시오.
+
+2. fish-speech 이미지 가져오기 및 실행
+
+ ```bash
+ # 이미지 가져오기
+ docker pull fishaudio/fish-speech:latest-dev
+ # 이미지 실행
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # 다른 포트를 사용하려면 -p 매개변수를 YourPort:7860으로 수정하세요
+ ```
+
+3. 모델 종속성 다운로드
+
+ Docker 컨테이너 내부의 터미널에서 아래 명령어를 사용하여 필요한 `vqgan` 및 `llama` 모델을 Huggingface 리포지토리에서 다운로드합니다.
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+ ```
+
+4. 환경 변수 설정 및 WebUI 접근
+
+ Docker 컨테이너 내부의 터미널에서 `export GRADIO_SERVER_NAME="0.0.0.0"`를 입력하여 Docker 내부에서 Gradio 서비스에 외부 접근을 허용합니다.
+ 이후, 터미널에서 `python tools/run_webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
+
+ WSL 또는 macOS를 사용하는 경우 [http://localhost:7860](http://localhost:7860)에서 WebUI 인터페이스를 열 수 있습니다.
+
+ 서버에 배포된 경우, localhost를 서버의 IP로 교체하세요.
+
+## 변경 사항
+
+- 2024/09/10: Fish-Speech 1.4 버전으로 업데이트, 데이터셋 크기 증가 및 양자화기의 n_groups를 4에서 8로 변경.
+- 2024/07/02: Fish-Speech 1.2 버전으로 업데이트, VITS 디코더 제거 및 제로샷 능력 크게 향상.
+- 2024/05/10: Fish-Speech 1.1 버전으로 업데이트, WER 감소 및 음색 유사성을 개선하기 위해 VITS 디코더 구현.
+- 2024/04/22: Fish-Speech 1.0 버전 완료, VQGAN 및 LLAMA 모델 대폭 수정.
+- 2023/12/28: `lora` 파인튜닝 지원 추가.
+- 2023/12/27: `gradient checkpointing`, `causual sampling`, 및 `flash-attn` 지원 추가.
+- 2023/12/19: WebUI 및 HTTP API 업데이트.
+- 2023/12/18: 파인튜닝 문서 및 관련 예시 업데이트.
+- 2023/12/17: `text2semantic` 모델 업데이트, 음소 없는 모드 지원.
+- 2023/12/13: 베타 버전 출시, VQGAN 모델 및 LLAMA 기반 언어 모델(음소 지원만 포함).
+
+## 감사의 말
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/ko/inference.md b/docs/ko/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e639c2360f0ca6bcf46886fda489f28083fd690
--- /dev/null
+++ b/docs/ko/inference.md
@@ -0,0 +1,134 @@
+# 추론
+
+추론은 명령줄, HTTP API, 그리고 웹 UI에서 지원됩니다.
+
+!!! note
+ 전체 추론 과정은 다음의 여러 단계로 구성됩니다:
+
+ 1. VQGAN을 사용하여 약 10초 분량의 음성을 인코딩합니다.
+ 2. 인코딩된 시맨틱 토큰과 해당 텍스트를 예시로 언어 모델에 입력합니다.
+ 3. 새로운 텍스트를 입력하면, 모델이 해당하는 시맨틱 토큰을 생성합니다.
+ 4. 생성된 시맨틱 토큰을 VITS / VQGAN에 입력하여 음성을 디코딩하고 생성합니다.
+
+## 명령줄 추론
+
+필요한 `vqgan` 및 `llama` 모델을 Hugging Face 리포지토리에서 다운로드하세요.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+### 1. 음성에서 프롬프트 생성:
+
+!!! note
+ 모델이 음색을 무작위로 선택하도록 하려면 이 단계를 건너뛸 수 있습니다.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+이 명령을 실행하면 `fake.npy` 파일을 얻게 됩니다.
+
+### 2. 텍스트에서 시맨틱 토큰 생성:
+
+```bash
+python tools/llama/generate.py \
+ --text "변환할 텍스트" \
+ --prompt-text "참고할 텍스트" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5" \
+ --num-samples 2 \
+ --compile
+```
+
+이 명령을 실행하면 작업 디렉토리에 `codes_N` 파일이 생성되며, N은 0부터 시작하는 정수입니다.
+
+!!! note
+ 빠른 추론을 위해 `--compile` 옵션을 사용하여 CUDA 커널을 결합할 수 있습니다 (~초당 30 토큰 -> ~초당 500 토큰).
+ `--compile` 매개변수를 주석 처리하여 가속화 옵션을 사용하지 않을 수도 있습니다.
+
+!!! info
+ bf16을 지원하지 않는 GPU의 경우 `--half` 매개변수를 사용해야 할 수 있습니다.
+
+### 3. 시맨틱 토큰에서 음성 생성:
+
+#### VQGAN 디코더
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API 추론
+
+추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다:
+
+```bash
+python -m tools.api_server \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다.
+
+이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다.
+
+아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다.
+
+```bash
+python -m tools.api_client \
+ --text "입력할 텍스트" \
+ --reference_audio "참고 음성 경로" \
+ --reference_text "참고 음성의 텍스트 내용" \
+ --streaming True
+```
+
+위 명령은 참고 음성 정보를 바탕으로 원하는 음성을 합성하고, 스트리밍 방식으로 반환합니다.
+
+다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다.
+
+```bash
+python -m tools.api_client \
+ --text "입력할 텍스트" \
+ --reference_audio "참고 음성 경로1" "참고 음성 경로2" \
+ --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\
+ --streaming False \
+ --output "generated" \
+ --format "mp3"
+```
+
+위 명령어는 여러 참고 음성 정보를 바탕으로 `MP3` 형식의 음성을 합성하여, 현재 디렉토리에 `generated.mp3`로 저장합니다.
+
+`--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다.
+
+!!! info
+ 제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다.
+
+## GUI 추론
+[클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases)
+
+## WebUI 추론
+
+다음 명령으로 WebUI를 시작할 수 있습니다:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> 추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다.
+
+!!! note
+ 라벨 파일과 참고 음성 파일을 미리 메인 디렉토리의 `references` 폴더에 저장해 두면, WebUI에서 바로 호출할 수 있습니다. (해당 폴더는 직접 생성해야 합니다.)
+
+!!! note
+ WebUI를 구성하기 위해 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`과 같은 Gradio 환경 변수를 사용할 수 있습니다.
+
+즐기세요!
diff --git a/docs/ko/samples.md b/docs/ko/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..5286a3647a120f0ba1d60eff6bd9a83575731aaf
--- /dev/null
+++ b/docs/ko/samples.md
@@ -0,0 +1,137 @@
+# 샘플
+
+ver 1.4
+
+## Credits
+[Seed-TTS (2024)](https://bytedancespeech.github.io/seedtts_tech_report/)에 감사드리며, 평가 데이터를 제공해 주셔서 이 데모를 완성할 수 있었습니다.
+
+모든 프롬프트 음성은 Seed-TTS 효과 데모 페이지에서 가져왔으며, 모든 생성된 음성은 fish-speech 버전 1.4에서 첫 번째로 생성된 것입니다.
+
+## 제로샷 인컨텍스트 학습
+- TODO: 한국어 제로샷 인컨텍스트 학습 샘플 추가. (현재는 영어와 중국어 데모만 제공됩니다.)
+
+
+
+
+ 언어
+ 프롬프트
+ 동일 언어 생성
+ 교차 언어 생성
+
+
+
+
+ EN
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me.
+ Your browser does not support the audio element. 处理家庭秘密从来都不是一件容易的事。然而,有时候,隐瞒是一种保护形式,旨在保护一些人免受残酷的真相伤害。有一天,我希望你能理解我行为背后的原因。在那之前,安娜,请容忍我。
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. The combinations of different textures and flavors create a perfect harmony. The succulence of the steak, the tartness of the cranberries, the crunch of pine nuts, and creaminess of blue cheese make it a truly delectable delight. Enjoy your culinary adventure!
+ Your browser does not support the audio element. 听着你的话,我心里五味杂陈。虽然我愿意一直在你身边,承担一切不幸,但我知道只有让你自己面对,才能真正让你变得更强大。所以,你要记得,无论面对何种困难,都请你坚强,我会在心里一直支持你的。
+
+
+ ZH
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+ Your browser does not support the audio element. Suddenly, there was a burst of laughter beside me. I looked at them, stood up straight with high spirit, shook the slightly fleshy arms, and smiled lightly, saying, "The flesh on my body is to hide my bursting charm. Otherwise, wouldn't it scare you?"
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 他闭上眼睛,期望这一切都能过去。然而,当他再次睁开眼睛,眼前的景象让他不禁倒吸一口气。雾气中出现的禁闭岛,陌生又熟悉,充满未知的危险。他握紧拳头,心知他的生活即将发生翻天覆地的改变。
+ Your browser does not support the audio element. He closed his eyes, expecting that all of this could pass. However, when he opened his eyes again, the sight in front of him made him couldn't help but take a deep breath. The closed island that appeared in the fog, strange and familiar, was full of unknown dangers. He tightened his fist, knowing that his life was about to undergo earth-shaking changes.
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+ Your browser does not support the audio element. Suddenly, the atmosphere became gloomy. At first glance, all the troubles seemed to surround me. I frowned, feeling that pressure, but I know I can't give up, can't admit defeat. So, I took a deep breath, and the voice in my heart told me, "Anyway, must calm down and start again."
+
+
+
+
+## 화자 파인튜닝
+
+
+
+
+
+ 텍스트
+ 생성된 음성
+
+
+
+
+ 화자1
+ 好呀,哈哈哈哈哈,喜欢笑的人运气都不会差哦,希望你每天笑口常开~
+ Your browser does not support the audio element.
+
+
+ 哇!恭喜你中了大乐透,八百万可真不少呢!有什么特别的计划或想法吗?
+ Your browser does not support the audio element.
+
+
+ 哼,你这么问是想请本小姐吃饭吗?如果对象是你的话,那也不是不可以。
+ Your browser does not support the audio element.
+
+
+ 화자2
+ 是呀,他还想换个地球仪哈哈哈,看来给你积累了一些快乐值了,你还想不想再听一个其他的笑话呀?
+ Your browser does not support the audio element.
+
+
+ 嘿嘿,你是不是也想拥有甜甜的恋爱呢?《微微一笑很倾城》是你的不二选择,男女主是校花校草类型,他们通过游戏结识,再到两人见面,全程没有一点误会,真的齁甜,想想都忍不住“姨妈笑”~
+ Your browser does not support the audio element.
+
+
+ 小傻瓜,嗯……算是个很可爱很亲切的名字,有点“独特”哦,不过我有些好奇,你为什么会给我选这个昵称呢?
+ Your browser does not support the audio element.
+
+
+
+
+
+## 콘텐츠 편집
+
+
+
+ 언어
+ 원본 텍스트
+ 원본 음성
+ 목표 텍스트
+ 편집된 음성
+
+
+
+ EN
+ They can't order me to stop dreaming. If you dream a thing more than once, it's sure to come true. Have faith in your dreams, and someday your rainbow will come shining through.
+ Your browser does not support the audio element.
+ They can't require me to stop imagining. If you envision a thing more than once, it's bound to come about . Have trust in your visions , and someday your radiance will come beaming through.
+ Your browser does not support the audio element.
+
+
+ Are you familiar with it? Slice the steak and place the strips on top, then garnish with the dried cranberries, pine nuts, and blue cheese. I wonder how people rationalise the decision?
+ Your browser does not support the audio element.
+ Are you acquainted with it? Cut the pork and place the strips on top, then garnish with the dried cherries, almonds, and feta cheese. I query how people justify the choice?
+ Your browser does not support the audio element.
+
+
+ ZH
+ 自古以来,庸君最怕党政了,可圣君他就不怕,不但不怕,反能利用。要我说,你就让明珠索额图互相争宠,只要你心里明白,左右逢源,你就能立于不败之地。
+ Your browser does not support the audio element.
+ 从古至今 ,庸君最怕朝纲了 ,可明 君他就不怕,不但不怕,反能借助 。要我说,你就让李四张三 互相争宠,只要你心里清楚 ,左右周旋 ,你就能处 于不败之境 。
+ Your browser does not support the audio element.
+
+
+ 对,这就是我,万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。
+ Your browser does not support the audio element.
+ 对,这就是我,众人尊崇 的太白金星 ,虽然有点娃娃脸 ,但也遮 不住我迷人 的魅力。
+ Your browser does not support the audio element.
+
+
+
diff --git a/docs/ko/start_agent.md b/docs/ko/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..c4d085de258ec0d19ed08632ea7b4129f0356b10
--- /dev/null
+++ b/docs/ko/start_agent.md
@@ -0,0 +1,80 @@
+# 에이전트 시작하기
+
+!!! note
+ 전체 문서는 claude3.5 Sonnet에 의해 번역되었으며, 원어민인 경우 번역에 문제가 있다고 생각되면 이슈나 풀 리퀘스트를 보내주셔서 대단히 감사합니다!
+
+## 요구사항
+
+- GPU 메모리: 최소 8GB(양자화 사용 시), 16GB 이상 권장
+- 디스크 사용량: 10GB
+
+## 모델 다운로드
+
+다음 명령어로 모델을 받을 수 있습니다:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+'checkpoints' 폴더에 파일들을 넣으세요.
+
+또한 [inference](inference.md)에 설명된 대로 fish-speech 모델도 다운로드해야 합니다.
+
+checkpoints에는 2개의 폴더가 있어야 합니다.
+
+`checkpoints/fish-speech-1.4`와 `checkpoints/fish-agent-v0.1-3b`입니다.
+
+## 환경 준비
+
+이미 Fish-speech가 있다면 다음 명령어를 추가하여 바로 사용할 수 있습니다:
+```bash
+pip install cachetools
+```
+
+!!! 참고
+ 컴파일을 위해 Python 3.12 미만 버전을 사용해 주세요.
+
+없다면 아래 명령어를 사용하여 환경을 구축하세요:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## 에이전트 데모 실행
+
+fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요:
+
+```bash
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+`--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다.
+
+한 번에 컴파일되지 않습니다(기억해 두세요).
+
+그런 다음 다른 터미널을 열고 다음 명령어를 사용하세요:
+
+```bash
+python -m tools.e2e_webui
+```
+
+이렇게 하면 기기에 Gradio WebUI가 생성됩니다.
+
+모델을 처음 사용할 때는 (`--compile`이 True인 경우) 잠시 컴파일이 진행되므로 기다려 주세요.
+
+## Gradio Webui
+
+
+
+
+즐거운 시간 되세요!
+
+## 성능
+
+테스트 결과, 4060 노트북은 겨우 실행되며 매우 부하가 큰 상태로, 초당 약 8토큰 정도만 처리합니다. 4090은 컴파일 상태에서 초당 약 95토큰을 처리하며, 이것이 저희가 권장하는 사양입니다.
+
+# 에이전트 소개
+
+이 데모는 초기 알파 테스트 버전으로, 추론 속도 최적화가 필요하며 수정해야 할 버그가 많이 있습니다. 버그를 발견하거나 수정하고 싶으시다면 이슈나 풀 리퀘스트를 보내주시면 매우 감사하겠습니다.
diff --git a/docs/pt/finetune.md b/docs/pt/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e7eb5c89570a52d04dbde1d17adf9031d09abd8
--- /dev/null
+++ b/docs/pt/finetune.md
@@ -0,0 +1,128 @@
+# Ajuste Fino
+
+É óbvio que ao abrir esta página, você não deve estar muito satisfeito com o desempenho do modelo pré-treinado com poucos exemplos. Você pode querer ajustar o modelo para melhorar seu desempenho em seu conjunto de dados.
+
+Na atual versão, a única coisa que você precisa ajustar é a parte do 'LLAMA'.
+
+## Ajuste Fino do LLAMA
+### 1. Preparando o conjunto de dados
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+Você precisa converter seu conjunto de dados para o formato acima e colocá-lo em `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`.
+
+!!! info
+ O arquivo de anotação `.lab` deve conter apenas a transcrição do áudio, sem a necessidade de formatação especial. Por exemplo, se o arquivo `hi.mp3` disser "Olá, tchau", o arquivo `hi.lab` conterá uma única linha de texto: "Olá, tchau".
+
+!!! warning
+ É recomendado aplicar normalização de volume ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. Extração em lote de tokens semânticos
+
+Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.
+ Para o formato VITS, você pode especificar uma lista de arquivos usando `--filelist xxx.list`.
+
+Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. Empacotar o conjunto de dados em protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.protos` no diretório `data`.
+
+### 4. E finalmente, chegamos ao ajuste fino com LoRA
+
+Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+E então, execute o seguinte comando para iniciar o ajuste fino:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ Se quiser, você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se ajustar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+ Para usuários do Windows, é recomendado usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`.
+
+Após concluir o treinamento, consulte a seção [inferência](inference.md).
+
+!!! info
+ Por padrão, o modelo aprenderá apenas os padrões de fala do orador e não o timbre. Ainda pode ser preciso usar prompts para garantir a estabilidade do timbre.
+ Se quiser que ele aprenda o timbre, aumente o número de etapas de treinamento, mas isso pode levar ao overfitting (sobreajuste).
+
+Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares antes de realizar a inferência.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.5 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.5-yth-lora/
+```
+!!! note
+ É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD).
diff --git a/docs/pt/index.md b/docs/pt/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..46fbc37605eca3f1875d1dc6c902a1bb7f2006c3
--- /dev/null
+++ b/docs/pt/index.md
@@ -0,0 +1,210 @@
+# Introdução
+
+
+
+!!! warning
+ Não nos responsabilizamos por qualquer uso ilegal do código-fonte. Consulte as leis locais sobre DMCA (Digital Millennium Copyright Act) e outras leis relevantes em sua região.
+ Este repositório de código e os modelos são distribuídos sob a licença CC-BY-NC-SA-4.0.
+
+
+
+
+
+## Requisitos
+
+- Memória da GPU: 4GB (para inferência), 8GB (para ajuste fino)
+- Sistema: Linux, Windows
+
+## Configuração do Windows
+
+Usuários profissionais do Windows podem considerar o uso do WSL2 ou Docker para executar a base de código.
+
+```bash
+# Crie um ambiente virtual Python 3.10, também é possível usar o virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Instale o pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# Instale o fish-speech
+pip3 install -e .
+
+# (Ativar aceleração) Instalar triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Usuários não profissionais do Windows podem considerar os seguintes métodos básicos para executar o projeto sem um ambiente Linux (com capacidades de compilação de modelo, ou seja, `torch.compile`):
+
+1. Extraia o pacote do projeto.
+2. Clique em `install_env.bat` para instalar o ambiente.
+3. Se você quiser ativar a aceleração de compilação, siga estas etapas:
+ 1. Baixe o compilador LLVM nos seguintes links:
+ - [LLVM-17.0.6 (Download do site oficial)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6 (Download do site espelho)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - Após baixar o `LLVM-17.0.6-win64.exe`, clique duas vezes para instalar, selecione um local de instalação apropriado e, o mais importante, marque a opção `Add Path to Current User` para adicionar a variável de ambiente.
+ - Confirme que a instalação foi concluída.
+ 2. Baixe e instale o Microsoft Visual C++ Redistributable para resolver possíveis problemas de arquivos .dll ausentes:
+ - [Download do MSVC++ 14.40.33810.0](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Baixe e instale o Visual Studio Community Edition para obter as ferramentas de compilação do MSVC++ e resolver as dependências dos arquivos de cabeçalho do LLVM:
+ - [Download do Visual Studio](https://visualstudio.microsoft.com/pt-br/downloads/)
+ - Após instalar o Visual Studio Installer, baixe o Visual Studio Community 2022.
+ - Conforme mostrado abaixo, clique no botão `Modificar`, encontre a opção `Desenvolvimento de área de trabalho com C++` e selecione para fazer o download.
+ 4. Baixe e instale o [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. Clique duas vezes em `start.bat` para abrir a interface de gerenciamento WebUI de inferência de treinamento. Se necessário, você pode modificar as `API_FLAGS` conforme mostrado abaixo.
+
+!!! info "Opcional"
+ Você quer iniciar o WebUI de inferência?
+ Edite o arquivo `API_FLAGS.txt` no diretório raiz do projeto e modifique as três primeiras linhas como segue:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "Opcional"
+ Você quer iniciar o servidor de API?
+ Edite o arquivo `API_FLAGS.txt` no diretório raiz do projeto e modifique as três primeiras linhas como segue:
+
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "Opcional"
+ Clique duas vezes em `run_cmd.bat` para entrar no ambiente de linha de comando conda/python deste projeto.
+
+
+## Configuração para Linux
+
+Para mais detalhes, consulte [pyproject.toml](../../pyproject.toml).
+```bash
+# Crie um ambiente virtual python 3.10, você também pode usar virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Instale o pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# Para os Usuário do Ubuntu / Debian: Instale o sox + ffmpeg
+apt install libsox-dev ffmpeg
+
+# Para os Usuário do Ubuntu / Debian: Instale o pyaudio
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# Instale o fish-speech
+pip3 install -e .[stable]
+```
+
+## Configuração para macos
+
+Se você quiser realizar inferências no MPS, adicione a flag `--device mps`.
+Para uma comparação das velocidades de inferência, consulte [este PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772).
+
+!!! aviso
+ A opção `compile` não é oficialmente suportada em dispositivos Apple Silicon, então não há garantia de que a velocidade de inferência irá melhorar.
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Configuração do Docker
+
+1. Instale o NVIDIA Container Toolkit:
+
+ Para usar a GPU com Docker para treinamento e inferência de modelos, você precisa instalar o NVIDIA Container Toolkit:
+
+ Para usuários Ubuntu:
+
+ ```bash
+ # Adicione o repositório remoto
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # Instale o nvidia-container-toolkit
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Reinicie o serviço Docker
+ sudo systemctl restart docker
+ ```
+
+ Para usuários de outras distribuições Linux, consulte o guia de instalação: [NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+
+2. Baixe e execute a imagem fish-speech
+
+ ```shell
+ # Baixe a imagem
+ docker pull fishaudio/fish-speech:latest-dev
+ # Execute a imagem
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # Se precisar usar outra porta, modifique o parâmetro -p para YourPort:7860
+ ```
+
+3. Baixe as dependências do modelo
+
+ Certifique-se de estar no terminal do contêiner Docker e, em seguida, baixe os modelos necessários `vqgan` e `llama` do nosso repositório HuggingFace.
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+ ```
+
+4. Configure as variáveis de ambiente e acesse a WebUI
+
+ No terminal do contêiner Docker, digite `export GRADIO_SERVER_NAME="0.0.0.0"` para permitir o acesso externo ao serviço gradio dentro do Docker.
+ Em seguida, no terminal do contêiner Docker, digite `python tools/run_webui.py` para iniciar o serviço WebUI.
+
+ Se estiver usando WSL ou MacOS, acesse [http://localhost:7860](http://localhost:7860) para abrir a interface WebUI.
+
+ Se estiver implantando em um servidor, substitua localhost pelo IP do seu servidor.
+
+## Histórico de Alterações
+- 10/09/2024: Fish-Speech atualizado para a versão 1.4, aumentado o tamanho do conjunto de dados, quantizer n_groups 4 -> 8.
+- 02/07/2024: Fish-Speech atualizado para a versão 1.2, removido o Decodificador VITS e aprimorado consideravelmente a capacidade de zero-shot.
+- 10/05/2024: Fish-Speech atualizado para a versão 1.1, implementado o decodificador VITS para reduzir a WER e melhorar a similaridade de timbre.
+- 22/04/2024: Finalizada a versão 1.0 do Fish-Speech, modificados significativamente os modelos VQGAN e LLAMA.
+- 28/12/2023: Adicionado suporte para ajuste fino `lora`.
+- 27/12/2023: Adicionado suporte para `gradient checkpointing`, `causual sampling` e `flash-attn`.
+- 19/12/2023: Atualizada a interface web e a API HTTP.
+- 18/12/2023: Atualizada a documentação de ajuste fino e exemplos relacionados.
+- 17/12/2023: Atualizado o modelo `text2semantic`, suportando o modo sem fonemas.
+- 13/12/2023: Versão beta lançada, incluindo o modelo VQGAN e um modelo de linguagem baseado em LLAMA (suporte apenas a fonemas).
+
+## Agradecimentos
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/pt/inference.md b/docs/pt/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..daae046da921bfc55289a406c669af53c0a7b320
--- /dev/null
+++ b/docs/pt/inference.md
@@ -0,0 +1,114 @@
+# Inferência
+
+Suporte para inferência por linha de comando, API HTTP e interface web (WebUI).
+
+!!! note
+ O processo de raciocínio, em geral, consiste em várias partes:
+
+ 1. Codificar cerca de 10 segundos de voz usando VQGAN.
+ 2. Inserir os tokens semânticos codificados e o texto correspondente no modelo de linguagem como um exemplo.
+ 3. Dado um novo trecho de texto, fazer com que o modelo gere os tokens semânticos correspondentes.
+ 4. Inserir os tokens semânticos gerados no VITS / VQGAN para decodificar e gerar a voz correspondente.
+
+## Inferência por Linha de Comando
+
+Baixe os modelos `vqgan` e `llama` necessários do nosso repositório Hugging Face.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+### 1. Gerar prompt a partir da voz:
+
+!!! note
+ Se quiser permitir que o modelo escolha aleatoriamente um timbre de voz, pule esta etapa.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+Você deverá obter um arquivo `fake.npy`.
+
+### 2. Gerar tokens semânticos a partir do texto:
+
+```bash
+python tools/llama/generate.py \
+ --text "O texto que você deseja converter" \
+ --prompt-text "Seu texto de referência" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5" \
+ --num-samples 2 \
+ --compile
+```
+
+Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é um número inteiro começando de 0.
+
+!!! note
+ Use `--compile` para fundir kernels CUDA para ter uma inferência mais rápida (~30 tokens/segundo -> ~500 tokens/segundo).
+ Mas, se não planeja usar a aceleração CUDA, comente o parâmetro `--compile`.
+
+!!! info
+ Para GPUs que não suportam bf16, pode ser necessário usar o parâmetro `--half`.
+
+### 3. Gerar vocais a partir de tokens semânticos:
+
+#### Decodificador VQGAN
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## Inferência por API HTTP
+
+Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor:
+
+```bash
+python -m tools.api_server \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> Para acelerar a inferência, adicione o parâmetro `--compile`.
+
+Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/.
+
+Abaixo está um exemplo de envio de uma solicitação usando `tools/api_client.py`.
+
+```bash
+python -m tools.api_client \
+ --text "Texto a ser inserido" \
+ --reference_audio "Caminho para o áudio de referência" \
+ --reference_text "Conteúdo de texto do áudio de referência" \
+ --streaming True
+```
+
+O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming.
+
+!!! info
+ Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.api_client -h`
+
+## Inferência por WebUI
+
+Para iniciar a WebUI de Inferência execute o seguinte comando:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> Para acelerar a inferência, adicione o parâmetro `--compile`.
+
+!!! note
+ Você pode salvar antecipadamente o arquivo de rótulos e o arquivo de áudio de referência na pasta `references` do diretório principal (que você precisa criar), para que possa chamá-los diretamente na WebUI.
+
+!!! note
+ É possível usar variáveis de ambiente do Gradio, como `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`, para configurar a WebUI.
+
+Divirta-se!
diff --git a/docs/pt/samples.md b/docs/pt/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..25042475692d5412ce1f78f4bf29552d974ab457
--- /dev/null
+++ b/docs/pt/samples.md
@@ -0,0 +1,225 @@
+# Amostras
+
+A demonstração da versão 1.4 foi atualizada [aqui](https://speech.fish.audio/samples/)
+
+As amostras da v1.2 estão disponíveis em [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/).
+
+As seguintes amostras são do modelo v1.1.
+
+## Frase em Chinês 1
+```
+人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Nahida (Genshin Impact)
+
+
+
+
+ Zhongli (Genshin Impact)
+
+
+
+
+ Furina (Genshin Impact)
+
+
+
+
+ Orador Aleatório 1
+ -
+
+
+
+ Orador Aleatório 2
+ -
+
+
+
+
+
+
+## Frase em Chinês 2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Nahida (Genshin Impact)
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
+
+
+## Frase em Chinês 3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
+
+## Frase em Inglês 1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório 1
+ -
+
+
+
+ Orador Aleatório 2
+ -
+
+
+
+
+
+## Frase em Inglês 2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
+
+## Frase em Japonês 1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório 1
+ -
+
+
+
+ Orador Aleatório 2
+ -
+
+
+
+
+
+## Frase em Japonês 2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
diff --git a/docs/pt/start_agent.md b/docs/pt/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..da6eed54e8ae3abeb498c501187d5651bcdeaf57
--- /dev/null
+++ b/docs/pt/start_agent.md
@@ -0,0 +1,80 @@
+# Iniciar Agente
+
+!!! note
+ Todo o documento foi traduzido por claude3.5 Sonnet, se você for um falante nativo e achar a tradução problemática, muito obrigado por nos enviar um problema ou uma solicitação pull!
+
+## Requisitos
+
+- Memória GPU: No mínimo 8GB (com quantização), 16GB ou mais é recomendado.
+- Uso de disco: 10GB
+
+## Download do Modelo
+
+Você pode obter o modelo através de:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+Coloque-os na pasta 'checkpoints'.
+
+Você também precisará do modelo fish-speech que pode ser baixado seguindo as instruções em [inference](inference.md).
+
+Então haverá 2 pastas em checkpoints.
+
+O `checkpoints/fish-speech-1.4` e `checkpoints/fish-agent-v0.1-3b`
+
+## Preparação do Ambiente
+
+Se você já tem o Fish-speech, pode usar diretamente adicionando a seguinte instrução:
+```bash
+pip install cachetools
+```
+
+!!! nota
+ Por favor, use a versão Python abaixo de 3.12 para compilação.
+
+Se você não tem, use os comandos abaixo para construir seu ambiente:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## Iniciar a Demo do Agente
+
+Para construir o fish-agent, use o comando abaixo na pasta principal:
+
+```bash
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens.
+
+Não será compilado de uma vez (lembre-se).
+
+Então abra outro terminal e use o comando:
+
+```bash
+python -m tools.e2e_webui
+```
+
+Isso criará uma WebUI Gradio no dispositivo.
+
+Quando você usar o modelo pela primeira vez, ele irá compilar (se `--compile` estiver True) por um curto período, então aguarde com paciência.
+
+## Gradio Webui
+
+
+
+
+Divirta-se!
+
+## Desempenho
+
+Em nossos testes, um laptop com 4060 mal consegue rodar, ficando muito sobrecarregado, gerando apenas cerca de 8 tokens/s. A 4090 gera cerca de 95 tokens/s com compilação, que é o que recomendamos.
+
+# Sobre o Agente
+
+A demo é uma versão alpha inicial de teste, a velocidade de inferência precisa ser otimizada, e há muitos bugs aguardando correção. Se você encontrou um bug ou quer corrigi-lo, ficaremos muito felizes em receber uma issue ou um pull request.
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6e145dbea1b9b26b2bddd7500e3f270b3eb0009
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,3 @@
+mkdocs-material
+mkdocs-static-i18n[material]
+mkdocs[i18n]
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
new file mode 100644
index 0000000000000000000000000000000000000000..a88af87b3cdbfd2d6b05f39877d5821bb7ebe119
--- /dev/null
+++ b/docs/stylesheets/extra.css
@@ -0,0 +1,3 @@
+.md-grid {
+ max-width: 1440px;
+}
diff --git a/docs/zh/finetune.md b/docs/zh/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b65f8b1486b6679e3fc9465359e4e62a92f93fa
--- /dev/null
+++ b/docs/zh/finetune.md
@@ -0,0 +1,139 @@
+# 微调
+
+显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.
+
+在目前版本,你只需要微调'LLAMA'部分即可.
+
+## LLAMA 微调
+### 1. 准备数据集
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`.
+
+!!! info
+ 标注文件 `.lab` 仅需包含音频的转写文本,无需遵循特殊格式要求。例如,如果 `hi.mp3` 中的内容是“你好,再见。”,那么 `hi.lab` 文件中只需包含一行文本:“你好,再见”。
+
+!!! warning
+ 建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤.
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+### 2. 批量提取语义 token
+
+确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+随后可运行以下命令来提取语义 token:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ 你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制.
+
+该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. 打包数据集为 protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件.
+
+
+### 4. 最后, 使用 LoRA 进行微调
+
+同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+最后, 你可以运行以下命令来启动微调:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ 你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存.
+
+!!! note
+ 对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题.
+
+训练结束后, 你可以参考 [推理](inference.md) 部分来测试你的模型.
+
+!!! info
+ 默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性.
+ 如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合.
+
+训练完成后, 你需要先将 loRA 的权重转为普通权重, 然后再进行推理.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.5 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.5-yth-lora/
+```
+
+!!! note
+ 你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.
diff --git a/docs/zh/index.md b/docs/zh/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..830258bcf2424a7c9df7e15b135040344e4c36d0
--- /dev/null
+++ b/docs/zh/index.md
@@ -0,0 +1,218 @@
+# 介绍
+
+
+
+!!! warning "警告"
+ 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+ 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
+
+
+
+
+
+## 要求
+
+- GPU 内存: 4GB (用于推理), 8GB (用于微调)
+- 系统: Linux, Windows
+
+## Windows 配置
+
+Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。
+
+```bash
+# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# 安装 pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# 安装 fish-speech
+pip3 install -e .
+
+# (开启编译加速) 安装 triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`):
+
+1. 解压项目压缩包。
+2. 点击 `install_env.bat` 安装环境。
+3. 若需要开启编译加速则执行这一步:
+ 1. 使用如下链接下载 LLVM 编译器。
+ - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
+ - 确认安装完成。
+ 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
+ - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
+ - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
+ - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
+ 4. 下载安装 [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`.
+
+!!! info "可选"
+
+ 想启动 推理 WebUI 界面?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "可选"
+
+ 想启动 API 服务器?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式:
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "可选"
+
+ 双击 `run_cmd.bat` 进入本项目的 conda/python 命令行环境
+
+## Linux 配置
+
+有关详细信息,请参见 [pyproject.toml](../../pyproject.toml)。
+```bash
+# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# 安装 pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# (Ubuntu / Debian 用户) 安装 sox + ffmpeg
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debian 用户) 安装 pyaudio
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# 安装 fish-speech
+pip3 install -e .[stable]
+```
+
+## macos 配置
+
+如果您想在 MPS 上进行推理,请添加 `--device mps` 标志。
+有关推理速度的比较,请参考 [此 PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772)。
+
+!!! 警告
+ `compile` 选项在 Apple Silicon 设备上尚未正式支持,因此推理速度没有提升的保证。
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Docker 配置
+
+1. 安装 NVIDIA Container Toolkit:
+
+ Docker 如果想使用 GPU 进行模型训练和推理,需要安装 NVIDIA Container Toolkit :
+
+ 对于 Ubuntu 用户:
+
+ ```bash
+ # 添加远程仓库
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # 安装 nvidia-container-toolkit
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # 重启 Docker 服务
+ sudo systemctl restart docker
+ ```
+
+ 对于使用其他 Linux 发行版的用户,安装指南请参考:[NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)。
+
+ 注:对于中国大陆的用户,您可能需要使用代理来完成相关工具的安装。
+
+2. 拉取并运行 fish-speech 镜像
+
+ ```shell
+ # 拉取镜像
+ docker pull fishaudio/fish-speech:latest-dev
+ # 运行镜像
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # 如果需要使用其他端口,请修改 -p 参数为 YourPort:7860
+ ```
+
+3. 下载模型依赖
+
+ 确保您在 docker 容器内的终端,然后再从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+ ```
+
+ 对于中国大陆用户,可以通过镜像站下载。
+
+ ```bash
+ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+ ```
+
+4. 配置环境变量,访问 WebUI
+
+ 在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。
+ 接着在 docker 容器内的终端,输入 `python tools/run_webui.py` 即可开启 WebUI 服务。
+
+ 如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。
+
+ 如果是部署在服务器上,更换 localhost 为您的服务器 ip 即可。
+
+## 更新日志
+
+- 2024/09/10: 更新了 Fish-Speech 到 1.4, 增加了数据集大小, quantizer n_groups 4 -> 8.
+- 2024/07/02: 更新了 Fish-Speech 到 1.2 版本,移除 VITS Decoder,同时极大幅度提升 zero-shot 能力.
+- 2024/05/10: 更新了 Fish-Speech 到 1.1 版本,引入了 VITS Decoder 来降低口胡和提高音色相似度.
+- 2024/04/22: 完成了 Fish-Speech 1.0 版本, 大幅修改了 VQGAN 和 LLAMA 模型.
+- 2023/12/28: 添加了 `lora` 微调支持.
+- 2023/12/27: 添加了 `gradient checkpointing`, `causual sampling` 和 `flash-attn` 支持.
+- 2023/12/19: 更新了 Webui 和 HTTP API.
+- 2023/12/18: 更新了微调文档和相关例子.
+- 2023/12/17: 更新了 `text2semantic` 模型, 支持无音素模式.
+- 2023/12/13: 测试版发布, 包含 VQGAN 模型和一个基于 LLAMA 的语言模型 (只支持音素).
+
+## 致谢
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/zh/inference.md b/docs/zh/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..4106b9bc8007db32c51bc11902bc86a59a8e9777
--- /dev/null
+++ b/docs/zh/inference.md
@@ -0,0 +1,143 @@
+# 推理
+
+推理支持命令行, http api, 以及 webui 三种方式.
+
+!!! note
+ 总的来说, 推理分为几个部分:
+
+ 1. 给定一段 ~10 秒的语音, 将它用 VQGAN 编码.
+ 2. 将编码后的语义 token 和对应文本输入语言模型作为例子.
+ 3. 给定一段新文本, 让模型生成对应的语义 token.
+ 4. 将生成的语义 token 输入 VQGAN 解码, 生成对应的语音.
+
+## 命令行推理
+
+从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+对于中国大陆用户,可使用 mirror 下载。
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
+```
+
+### 1. 从语音生成 prompt:
+
+!!! note
+ 如果你打算让模型随机选择音色, 你可以跳过这一步.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+你应该能得到一个 `fake.npy` 文件.
+
+### 2. 从文本生成语义 token:
+
+```bash
+python tools/llama/generate.py \
+ --text "要转换的文本" \
+ --prompt-text "你的参考文本" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5" \
+ --num-samples 2 \
+ --compile
+```
+
+该命令会在工作目录下创建 `codes_N` 文件, 其中 N 是从 0 开始的整数.
+
+!!! note
+ 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒).
+ 对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数.
+
+!!! info
+ 对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数.
+
+### 3. 从语义 token 生成人声:
+
+#### VQGAN 解码
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API 推理
+
+运行以下命令来启动 HTTP 服务:
+
+```bash
+python -m tools.api_server \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> 如果你想要加速推理,可以加上`--compile`参数。
+
+推荐中国大陆用户运行以下命令来启动 HTTP 服务:
+```bash
+HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)
+```
+
+随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.
+
+下面是使用`tools/api_client.py`发送请求的示例。
+
+```bash
+python -m tools.api_client \
+ --text "要输入的文本" \
+ --reference_audio "参考音频路径" \
+ --reference_text "参考音频的文本内容" \
+ --streaming True
+```
+
+上面的命令表示按照参考音频的信息,合成所需的音频并流式返回.
+
+下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。
+```bash
+python -m tools.api_client \
+ --text "要输入的文本" \
+ --reference_audio "参考音频路径1" "参考音频路径2" \
+ --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\
+ --streaming False \
+ --output "generated" \
+ --format "mp3"
+```
+
+上面的命令表示按照多个参考音频的信息,合成所需的`MP3`格式音频,并保存为当前目录的`generated.mp3`文件。
+
+还可以用`--reference_id`(仅能用一个)来代替`--reference_audio`和`--reference_text`, 前提是在项目根目录下创建`references/`文件夹,
+里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。
+
+!!! info
+ 要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h`
+
+## GUI 推理
+[下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases)
+
+## WebUI 推理
+
+你可以使用以下命令来启动 WebUI:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> 如果你想要加速推理,可以加上`--compile`参数。
+
+!!! note
+ 你可以提前将label文件和参考音频文件保存到主目录下的 `references` 文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。
+
+!!! note
+ 你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI.
+
+祝大家玩得开心!
diff --git a/docs/zh/samples.md b/docs/zh/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..bdcce108de01aa6b0da5fdea737fdd41450f6396
--- /dev/null
+++ b/docs/zh/samples.md
@@ -0,0 +1,225 @@
+# 例子
+
+v1.4 演示已更新至[此处](https://speech.fish.audio/samples/)。
+
+v1.2 的样本可以在 [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/) 观看。
+
+以下样本来自 v1.1 版本的模型。
+
+## 中文句子 1
+```
+人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 纳西妲 (原神)
+
+
+
+
+ 钟离 (原神)
+
+
+
+
+ 芙宁娜 (原神)
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+
+## 中文句子 2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 纳西妲 (原神)
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+
+## 中文句子 3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+## 英文句子 1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+## 英文句子 2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+## 日文句子 1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+## 日文句子 2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
diff --git a/docs/zh/start_agent.md b/docs/zh/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..c93b9068fb74060c2a97c79e845d7435a0fd383c
--- /dev/null
+++ b/docs/zh/start_agent.md
@@ -0,0 +1,83 @@
+# 启动 Agent
+
+## 要求
+
+- GPU 显存: 至少 8GB(在量化的条件下),推荐 16GB 及以上
+- 硬盘使用量: 10GB
+
+## 下载模型
+
+你可以执行下面的语句来获取模型:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+如果你处于国内网络,首先执行:
+
+```bash
+export HF_ENDPOINT=https://hf-mirror.com
+```
+
+把他们放进名为 'checkpoints' 的文件夹内。
+
+你同样需要 fish-speech 的模型,关于如何获取 fish-speech 模型请查看[inference](inference.md)。
+
+完成后你的 checkpoints 文件夹中会有两个子文件夹:`checkpoints/fish-speech-1.4` 和 `checkpoints/fish-agent-v0.1-3b`。
+
+## Environment Prepare
+
+如果你已经有了 Fish-Speech 环境,你可以在安装下面的包的前提下直接使用:
+
+```bash
+pip install cachetools
+```
+
+!!! note
+请使用小于 3.12 的 python 版本使 compile 可用
+
+如果你没有 Fish-Speech 环境,请执行下面的语句来构造你的环境:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## 链接 Agent.
+
+你需要使用以下指令来构建 fish-agent
+
+```bash
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+`--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。
+
+你需要哦注意 compile 需要进行一段时间.
+
+然后启动另一个终端并执行:
+
+```bash
+python -m tools.e2e_webui
+```
+
+这会在设备上创建一个 Gradio WebUI。
+
+每当进行第一轮对话的时候,模型需要 compile 一段时间,请耐心等待
+
+## Gradio Webui
+
+
+
+
+
+玩得开心!
+
+## Performance
+
+在我们的测试环境下, 4060 laptop GPU 只能刚刚运行该模型,只有大概 8 tokens/s。 4090 CPU 可以在编译后达到 95 tokens/s,我们推荐使用至少 4080 以上级别的 GPU 来达到较好体验。
+
+# About Agent
+
+该模型仍处于测试阶段。如果你发现了问题,请给我们提 issue 或者 pull request,我们非常感谢。
diff --git a/entrypoint.sh b/entrypoint.sh
new file mode 100755
index 0000000000000000000000000000000000000000..eb4564e090c977fce69efc689ae4d7381e43dd5b
--- /dev/null
+++ b/entrypoint.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+CUDA_ENABLED=${CUDA_ENABLED:-true}
+DEVICE=""
+
+if [ "${CUDA_ENABLED}" != "true" ]; then
+ DEVICE="--device cpu"
+fi
+
+exec python tools/run_webui.py ${DEVICE}
diff --git a/fish_speech.egg-info/PKG-INFO b/fish_speech.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..c7f90c5c84cc542f2f82e4036650e5356caf8361
--- /dev/null
+++ b/fish_speech.egg-info/PKG-INFO
@@ -0,0 +1,188 @@
+Metadata-Version: 2.1
+Name: fish-speech
+Version: 0.1.0
+Summary: Fish Speech
+Author-email: Lengyue
+License: CC BY-NC-SA 4.0
+Keywords: TTS,Speech
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy<=1.26.4
+Requires-Dist: transformers>=4.45.2
+Requires-Dist: datasets==2.18.0
+Requires-Dist: lightning>=2.1.0
+Requires-Dist: hydra-core>=1.3.2
+Requires-Dist: tensorboard>=2.14.1
+Requires-Dist: natsort>=8.4.0
+Requires-Dist: einops>=0.7.0
+Requires-Dist: librosa>=0.10.1
+Requires-Dist: rich>=13.5.3
+Requires-Dist: gradio>5.0.0
+Requires-Dist: wandb>=0.15.11
+Requires-Dist: grpcio>=1.58.0
+Requires-Dist: kui>=1.6.0
+Requires-Dist: uvicorn>=0.30.0
+Requires-Dist: loguru>=0.6.0
+Requires-Dist: loralib>=0.1.2
+Requires-Dist: pyrootutils>=1.0.4
+Requires-Dist: vector_quantize_pytorch==1.14.24
+Requires-Dist: resampy>=0.4.3
+Requires-Dist: einx[torch]==0.2.2
+Requires-Dist: zstandard>=0.22.0
+Requires-Dist: pydub
+Requires-Dist: pyaudio
+Requires-Dist: faster_whisper
+Requires-Dist: modelscope==1.17.1
+Requires-Dist: funasr==1.1.5
+Requires-Dist: opencc-python-reimplemented==0.1.7
+Requires-Dist: silero-vad
+Requires-Dist: ormsgpack
+Requires-Dist: tiktoken>=0.8.0
+Requires-Dist: pydantic==2.9.2
+Requires-Dist: cachetools
+Provides-Extra: stable
+Requires-Dist: torch<=2.4.1; extra == "stable"
+Requires-Dist: torchaudio; extra == "stable"
+
+
+
Fish Speech
+
+**English** | [简体中文](docs/README.zh.md) | [Portuguese](docs/README.pt-BR.md) | [日本語](docs/README.ja.md) | [한국어](docs/README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+This codebase is released under Apache License and all model weights are released under CC-BY-NC-SA-4.0 License. Please refer to [LICENSE](LICENSE) for more details.
+
+---
+## Fish Agent
+We are very excited to announce that we have made our self-research agent demo open source, you can now try our agent demo online at [demo](https://fish.audio/demo/live) for instant English chat and English and Chinese chat locally by following the [docs](https://speech.fish.audio/start_agent/).
+
+You should mention that the content is released under a **CC BY-NC-SA 4.0 licence**. And the demo is an early alpha test version, the inference speed needs to be optimised, and there are a lot of bugs waiting to be fixed. If you've found a bug or want to fix it, we'd be very happy to receive an issue or a pull request.
+
+## Features
+### Fish Speech
+
+1. **Zero-shot & Few-shot TTS:** Input a 10 to 30-second vocal sample to generate high-quality TTS output. **For detailed guidelines, see [Voice Cloning Best Practices](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+
+2. **Multilingual & Cross-lingual Support:** Simply copy and paste multilingual text into the input box—no need to worry about the language. Currently supports English, Japanese, Korean, Chinese, French, German, Arabic, and Spanish.
+
+3. **No Phoneme Dependency:** The model has strong generalization capabilities and does not rely on phonemes for TTS. It can handle text in any language script.
+
+4. **Highly Accurate:** Achieves a low CER (Character Error Rate) and WER (Word Error Rate) of around 2% for 5-minute English texts.
+
+5. **Fast:** With fish-tech acceleration, the real-time factor is approximately 1:5 on an Nvidia RTX 4060 laptop and 1:15 on an Nvidia RTX 4090.
+
+6. **WebUI Inference:** Features an easy-to-use, Gradio-based web UI compatible with Chrome, Firefox, Edge, and other browsers.
+
+7. **GUI Inference:** Offers a PyQt6 graphical interface that works seamlessly with the API server. Supports Linux, Windows, and macOS. [See GUI](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **Deploy-Friendly:** Easily set up an inference server with native support for Linux, Windows and MacOS, minimizing speed loss.
+
+### Fish Agent
+1. **Completely End to End:** Automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
+
+2. **Timbre Control:** Can use reference audio to control the speech timbre.
+
+3. **Emotional:** The model can generate speech with strong emotion.
+
+## Disclaimer
+
+We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
+
+## Online Demo
+
+[Fish Audio](https://fish.audio)
+
+[Fish Agent](https://fish.audio/demo/live)
+
+## Quick Start for Local Inference
+
+[inference.ipynb](/inference.ipynb)
+
+## Videos
+
+#### V1.4 Demo Video: [Youtube](https://www.youtube.com/watch?v=Ghc8cJdQyKQ)
+
+## Documents
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+
+## Samples (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+
+## Credits
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Tech Report (V1.4)
+```bibtex
+@misc{fish-speech-v1.4,
+ title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
+ author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
+ year={2024},
+ eprint={2411.01156},
+ archivePrefix={arXiv},
+ primaryClass={cs.SD},
+ url={https://arxiv.org/abs/2411.01156},
+}
+```
+
+## Sponsor
+
+
+
diff --git a/fish_speech.egg-info/SOURCES.txt b/fish_speech.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..58be9d7b55734ebdad0e0f85962394300c578e42
--- /dev/null
+++ b/fish_speech.egg-info/SOURCES.txt
@@ -0,0 +1,177 @@
+.dockerignore
+.gitignore
+.pre-commit-config.yaml
+.project-root
+.readthedocs.yaml
+API_FLAGS.txt
+LICENSE
+README.md
+docker-compose.dev.yml
+dockerfile
+dockerfile.dev
+entrypoint.sh
+inference.ipynb
+install_env.bat
+mkdocs.yml
+pyproject.toml
+pyrightconfig.json
+run_cmd.bat
+start.bat
+.github/pull_request_template.md
+.github/ISSUE_TEMPLATE/bug_report.yml
+.github/ISSUE_TEMPLATE/config.yml
+.github/ISSUE_TEMPLATE/feature_request.yml
+.github/workflows/build-docker-image.yml
+.github/workflows/docs.yml
+.github/workflows/stale.yml
+docs/CNAME
+docs/README.ja.md
+docs/README.ko.md
+docs/README.pt-BR.md
+docs/README.zh.md
+docs/requirements.txt
+docs/assets/figs/VS_1.jpg
+docs/assets/figs/VS_1_pt-BR.png
+docs/assets/figs/agent_gradio.png
+docs/assets/figs/diagram.png
+docs/assets/figs/diagrama.png
+docs/assets/figs/logo-circle.png
+docs/en/finetune.md
+docs/en/index.md
+docs/en/inference.md
+docs/en/samples.md
+docs/en/start_agent.md
+docs/ja/finetune.md
+docs/ja/index.md
+docs/ja/inference.md
+docs/ja/samples.md
+docs/ja/start_agent.md
+docs/ko/finetune.md
+docs/ko/index.md
+docs/ko/inference.md
+docs/ko/samples.md
+docs/ko/start_agent.md
+docs/pt/finetune.md
+docs/pt/index.md
+docs/pt/inference.md
+docs/pt/samples.md
+docs/pt/start_agent.md
+docs/stylesheets/extra.css
+docs/zh/finetune.md
+docs/zh/index.md
+docs/zh/inference.md
+docs/zh/samples.md
+docs/zh/start_agent.md
+fish_speech/conversation.py
+fish_speech/scheduler.py
+fish_speech/tokenizer.py
+fish_speech/train.py
+fish_speech.egg-info/PKG-INFO
+fish_speech.egg-info/SOURCES.txt
+fish_speech.egg-info/dependency_links.txt
+fish_speech.egg-info/requires.txt
+fish_speech.egg-info/top_level.txt
+fish_speech/callbacks/__init__.py
+fish_speech/callbacks/grad_norm.py
+fish_speech/configs/base.yaml
+fish_speech/configs/firefly_gan_vq.yaml
+fish_speech/configs/text2semantic_finetune.yaml
+fish_speech/configs/lora/r_8_alpha_16.yaml
+fish_speech/datasets/concat_repeat.py
+fish_speech/datasets/semantic.py
+fish_speech/datasets/vqgan.py
+fish_speech/datasets/protos/text-data.proto
+fish_speech/datasets/protos/text_data_pb2.py
+fish_speech/datasets/protos/text_data_stream.py
+fish_speech/i18n/README.md
+fish_speech/i18n/__init__.py
+fish_speech/i18n/core.py
+fish_speech/i18n/scan.py
+fish_speech/i18n/locale/en_US.json
+fish_speech/i18n/locale/es_ES.json
+fish_speech/i18n/locale/ja_JP.json
+fish_speech/i18n/locale/ko_KR.json
+fish_speech/i18n/locale/pt_BR.json
+fish_speech/i18n/locale/zh_CN.json
+fish_speech/models/text2semantic/__init__.py
+fish_speech/models/text2semantic/lit_module.py
+fish_speech/models/text2semantic/llama.py
+fish_speech/models/text2semantic/lora.py
+fish_speech/models/vqgan/__init__.py
+fish_speech/models/vqgan/utils.py
+fish_speech/models/vqgan/modules/firefly.py
+fish_speech/models/vqgan/modules/fsq.py
+fish_speech/text/__init__.py
+fish_speech/text/clean.py
+fish_speech/text/spliter.py
+fish_speech/text/chn_text_norm/.gitignore
+fish_speech/text/chn_text_norm/README.md
+fish_speech/text/chn_text_norm/__init__.py
+fish_speech/text/chn_text_norm/basic_class.py
+fish_speech/text/chn_text_norm/basic_constant.py
+fish_speech/text/chn_text_norm/basic_util.py
+fish_speech/text/chn_text_norm/cardinal.py
+fish_speech/text/chn_text_norm/date.py
+fish_speech/text/chn_text_norm/digit.py
+fish_speech/text/chn_text_norm/fraction.py
+fish_speech/text/chn_text_norm/money.py
+fish_speech/text/chn_text_norm/percentage.py
+fish_speech/text/chn_text_norm/telephone.py
+fish_speech/text/chn_text_norm/text.py
+fish_speech/utils/__init__.py
+fish_speech/utils/braceexpand.py
+fish_speech/utils/context.py
+fish_speech/utils/file.py
+fish_speech/utils/instantiators.py
+fish_speech/utils/logger.py
+fish_speech/utils/logging_utils.py
+fish_speech/utils/rich_utils.py
+fish_speech/utils/spectrogram.py
+fish_speech/utils/utils.py
+fish_speech/webui/launch_utils.py
+fish_speech/webui/manage.py
+fish_speech/webui/css/style.css
+fish_speech/webui/html/footer.html
+fish_speech/webui/js/animate.js
+tools/api_client.py
+tools/api_server.py
+tools/download_models.py
+tools/e2e_webui.py
+tools/extract_model.py
+tools/file.py
+tools/fish_e2e.py
+tools/run_webui.py
+tools/schema.py
+tools/smart_pad.py
+tools/whisper_asr.py
+tools/inference_engine/__init__.py
+tools/inference_engine/reference_loader.py
+tools/inference_engine/utils.py
+tools/inference_engine/vq_manager.py
+tools/llama/build_dataset.py
+tools/llama/eval_in_context.py
+tools/llama/generate.py
+tools/llama/merge_lora.py
+tools/llama/quantize.py
+tools/llama/rebuild_tokenizer.py
+tools/sensevoice/README.md
+tools/sensevoice/__init__.py
+tools/sensevoice/auto_model.py
+tools/sensevoice/fun_asr.py
+tools/sensevoice/vad_utils.py
+tools/server/api_utils.py
+tools/server/exception_handler.py
+tools/server/inference.py
+tools/server/model_manager.py
+tools/server/model_utils.py
+tools/server/views.py
+tools/server/agent/__init__.py
+tools/server/agent/generate.py
+tools/server/agent/generation_utils.py
+tools/server/agent/pre_generation_utils.py
+tools/vqgan/create_train_split.py
+tools/vqgan/extract_vq.py
+tools/vqgan/inference.py
+tools/webui/__init__.py
+tools/webui/inference.py
+tools/webui/variables.py
\ No newline at end of file
diff --git a/fish_speech.egg-info/dependency_links.txt b/fish_speech.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/fish_speech.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/fish_speech.egg-info/requires.txt b/fish_speech.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6faa1a0a6fa3c43135cc68c6fe3995cdd890d352
--- /dev/null
+++ b/fish_speech.egg-info/requires.txt
@@ -0,0 +1,37 @@
+numpy<=1.26.4
+transformers>=4.45.2
+datasets==2.18.0
+lightning>=2.1.0
+hydra-core>=1.3.2
+tensorboard>=2.14.1
+natsort>=8.4.0
+einops>=0.7.0
+librosa>=0.10.1
+rich>=13.5.3
+gradio>5.0.0
+wandb>=0.15.11
+grpcio>=1.58.0
+kui>=1.6.0
+uvicorn>=0.30.0
+loguru>=0.6.0
+loralib>=0.1.2
+pyrootutils>=1.0.4
+vector_quantize_pytorch==1.14.24
+resampy>=0.4.3
+einx[torch]==0.2.2
+zstandard>=0.22.0
+pydub
+pyaudio
+faster_whisper
+modelscope==1.17.1
+funasr==1.1.5
+opencc-python-reimplemented==0.1.7
+silero-vad
+ormsgpack
+tiktoken>=0.8.0
+pydantic==2.9.2
+cachetools
+
+[stable]
+torch<=2.4.1
+torchaudio
diff --git a/fish_speech.egg-info/top_level.txt b/fish_speech.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..46946680c8304866f1d6fe4ea04cfaf65b4a68e6
--- /dev/null
+++ b/fish_speech.egg-info/top_level.txt
@@ -0,0 +1,2 @@
+fish_speech
+tools
diff --git a/fish_speech/__pycache__/conversation.cpython-310.pyc b/fish_speech/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a80e19402f524db0b9c77ad02eec5ea97f1cbed
Binary files /dev/null and b/fish_speech/__pycache__/conversation.cpython-310.pyc differ
diff --git a/fish_speech/__pycache__/tokenizer.cpython-310.pyc b/fish_speech/__pycache__/tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75245a1a9a607f56fe93261dce9523ca449f1f40
Binary files /dev/null and b/fish_speech/__pycache__/tokenizer.cpython-310.pyc differ
diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3
--- /dev/null
+++ b/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]
diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a
--- /dev/null
+++ b/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+ _group_tensors_by_device_and_dtype,
+ _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+ parameters: Union[Tensor, list[Tensor]],
+ norm_type: float = 2.0,
+) -> float:
+ """
+ Returns the norm of the gradients of the given parameters.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ norm_type (float): type of the used p-norm.
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """ # noqa: E501
+
+ if isinstance(parameters, Tensor):
+ parameters = [parameters]
+
+ grads = [p.grad for p in parameters if p.grad is not None]
+ if len(grads) == 0:
+ return None
+
+ first_device = grads[0].device
+ grouped_grads: dict[
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
+ ] = _group_tensors_by_device_and_dtype(
+ [[g.detach() for g in grads]]
+ ) # type: ignore[assignment]
+
+ norms = []
+ for (device, _), ([grads], _) in grouped_grads.items():
+ if _has_foreach_support(grads, device=device):
+ norms.extend(torch._foreach_norm(grads, norm_type))
+ else:
+ norms.extend([torch.norm(g, norm_type) for g in grads])
+
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+ """
+ Callback that computes the gradient norm of the model parameters.
+ """
+
+ def __init__(
+ self,
+ norm_type: float = 2.0,
+ logging_interval: str = "step",
+ sub_module: Optional[Union[str, list[str]]] = None,
+ ) -> None:
+ """
+ Args:
+ norm_type (float): type of the used p-norm.
+ logging_interval (str): "step" or "epoch".
+ """
+ super().__init__()
+
+ self.norm_type = norm_type
+ self.logging_interval = logging_interval
+ self.sub_module = sub_module
+
+ def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+ """
+ Computes the gradient norm of the model parameters and logs it to the logger.
+
+ Args:
+ trainer (Trainer): The trainer object
+ model (LightningModule): The current lightningModule
+ """
+
+ lightning_model = model
+
+ if self.sub_module is None:
+ return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+ sub_modules = self.sub_module
+ if isinstance(sub_modules, str):
+ sub_modules = [sub_modules]
+
+ for sub_module in sub_modules:
+ self.log_sub_module_grad_norm(
+ lightning_model, getattr(model, sub_module), f"/{sub_module}"
+ )
+
+ def log_sub_module_grad_norm(
+ self, lightning_model: LightningModule, model: nn.Module, path: str
+ ) -> None:
+ grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+ if grad_norm_val is None:
+ return
+
+ on_step = self.logging_interval == "step"
+ lightning_model.log(
+ f"train{path}/grad_norm",
+ grad_norm_val,
+ on_step=on_step,
+ on_epoch=not on_step,
+ )
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75
--- /dev/null
+++ b/fish_speech/configs/base.yaml
@@ -0,0 +1,87 @@
+# Base configuration for training a model
+paths:
+ run_dir: results/${project}
+ ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+ run:
+ dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.run_dir}
+ accelerator: gpu
+ num_nodes: 1
+ devices: auto
+ strategy:
+ _target_: lightning.pytorch.strategies.DDPStrategy
+ process_group_backend: nccl # This should be override when training on windows
+
+ precision: bf16-mixed
+
+ # disable validation by epoch end
+ check_val_every_n_epoch: null
+ val_check_interval: 5000
+ max_steps: 100_000
+
+ # Use torch.backends.cudnn.benchmark to speed up training
+ benchmark: true
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: ${paths.ckpt_dir}
+ filename: "step_{step:09d}"
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 5 # save 5 latest checkpoints
+ monitor: step # use step to monitor checkpoints
+ mode: max # save the latest checkpoint with the highest global_step
+ every_n_epochs: null # don't save checkpoints by epoch end
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
+ auto_insert_metric_name: false
+
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+ learning_rate_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
+
+ grad_norm_monitor:
+ _target_: fish_speech.callbacks.GradNormMonitor
+ norm_type: 2
+ logging_interval: step
+
+# Logger
+logger:
+ tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.run_dir}/tensorboard/"
+ name: null
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+
+ # wandb:
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # # name: "" # name of the run (normally generated by wandb)
+ # save_dir: "${paths.run_dir}"
+ # offline: False
+ # id: null # pass correct id to resume experiment!
+ # anonymous: null # enable anonymous logging
+ # project: "fish-speech"
+ # log_model: False # upload lightning ckpts
+ # prefix: "" # a string to put at the beginning of metric keys
+ # # entity: "" # set to name of your wandb team
+ # group: ""
+ # tags: ["vq", "hq", "finetune"]
+ # job_type: ""
+
+# Loop
+train: true
+test: false
diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..10aa8d4a522f0859ed8f541f5d48672d84b39c8f
--- /dev/null
+++ b/fish_speech/configs/firefly_gan_vq.yaml
@@ -0,0 +1,33 @@
+_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
+spec_transform:
+ _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+ sample_rate: 44100
+ n_mels: 160
+ n_fft: 2048
+ hop_length: 512
+ win_length: 2048
+backbone:
+ _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
+ input_channels: 160
+ depths: [3, 3, 9, 3]
+ dims: [128, 256, 384, 512]
+ drop_path_rate: 0.2
+ kernel_size: 7
+head:
+ _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
+ hop_length: 512
+ upsample_rates: [8, 8, 2, 2, 2] # aka. strides
+ upsample_kernel_sizes: [16, 16, 4, 4, 4]
+ resblock_kernel_sizes: [3, 7, 11]
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+ num_mels: 512
+ upsample_initial_channel: 512
+ pre_conv_kernel_size: 13
+ post_conv_kernel_size: 13
+quantizer:
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+ input_dim: 512
+ n_groups: 8
+ n_codebooks: 1
+ levels: [8, 5, 5, 5]
+ downsample_factor: [2, 2]
diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9
--- /dev/null
+++ b/fish_speech/configs/lora/r_8_alpha_16.yaml
@@ -0,0 +1,4 @@
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
+r: 8
+lora_alpha: 16
+lora_dropout: 0.01
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4c1993023099e122fc9e004bda55ec075ed5e1b
--- /dev/null
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -0,0 +1,83 @@
+defaults:
+ - base
+ - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 4096
+pretrained_ckpt_path: checkpoints/fish-speech-1.4
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: "norm"
+ max_steps: 1000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 100
+
+# Dataset Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+val_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+data:
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+ model:
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+ path: ${pretrained_ckpt_path}
+ load_weights: true
+ max_length: ${max_length}
+ lora_config: null
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 10
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..20d8ab3292351a305b4572570b507ceb08da952c
--- /dev/null
+++ b/fish_speech/conversation.py
@@ -0,0 +1,267 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+
+from .tokenizer import MODALITY_TOKENS, FishTokenizer
+
+CODEBOOK_PAD_TOKEN_ID = 0
+
+
+@dataclass(kw_only=True)
+class BasePart:
+ pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+ codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+ text: str
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+ tokens: torch.Tensor
+ labels: torch.Tensor
+ vq_mask_tokens: torch.Tensor | None = None
+ vq_mask_labels: torch.Tensor | None = None
+ vq_parts: list[torch.Tensor]
+ vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+ role: Literal["system", "user", "assistant"]
+ parts: list[VQPart | TextPart] = field(default_factory=list)
+ add_im_start: bool = True
+ add_im_end: bool = True
+ cal_loss: bool = False
+ modality: Literal["text", "voice", "interleave"] | None = None
+
+ # By default, ignore the loss of the auto-generated im_start token
+ ignore_im_start_loss: bool = True
+
+ def encode(
+ self: "Message",
+ tokenizer: FishTokenizer,
+ ) -> EncodedMessage:
+ all_tokens = []
+ all_labels = []
+
+ # Multi-modal tokens
+ vq_parts = []
+ vq_masks = []
+
+ parts = self.parts.copy()
+ if self.add_im_start:
+ modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
+ parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))
+
+ if self.add_im_end:
+ parts.append(TextPart(text="<|im_end|>"))
+
+ for part in parts:
+ if isinstance(part, TextPart):
+ tokens = torch.tensor(
+ tokenizer.encode(part.text),
+ dtype=torch.int,
+ )
+ elif isinstance(part, VQPart):
+ curr_codes = part.codes.clone()
+ tokens = torch.tensor(
+ [
+ tokenizer.semantic_id_to_token_id[i.item()]
+ for i in curr_codes[0].int()
+ ],
+ dtype=torch.int,
+ )
+ vq_parts.append(curr_codes)
+ else:
+ raise ValueError(f"Unsupported part type: {type(part)}")
+
+ all_tokens.append(tokens)
+ if isinstance(part, VQPart):
+ vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+ else:
+ vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
+ if self.cal_loss:
+ all_labels.append(tokens.clone())
+ else:
+ all_labels.append(torch.full_like(tokens, -100))
+
+ tokens = torch.cat(all_tokens, dim=0)
+ labels = torch.cat(all_labels, dim=0)
+ vq_masks = torch.cat(vq_masks, dim=0)
+
+ assert tokens.shape == labels.shape == vq_masks.shape
+
+ if self.ignore_im_start_loss and self.add_im_start:
+ labels[: len(all_tokens[0])] = -100
+
+ return EncodedMessage(
+ tokens=tokens,
+ labels=labels,
+ vq_parts=vq_parts,
+ vq_mask_tokens=vq_masks,
+ vq_mask_labels=vq_masks,
+ )
+
+
+@dataclass
+class Conversation:
+ messages: list[Message]
+
+ def __init__(self: "Conversation", messages: list[Message] | None = None):
+ self.messages = messages or []
+
+ def encode(
+ self: "Conversation",
+ tokenizer: FishTokenizer,
+ add_shift: bool = True,
+ ignore_loss_tokens: list[str] = [],
+ ) -> EncodedMessage:
+ # Build the input_ids and labels
+ tokens = []
+ labels = []
+ vq_parts = []
+ vq_mask_tokens = []
+ vq_mask_labels = []
+ vq_require_losses = []
+ ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
+
+ for message in self.messages:
+ encoded = message.encode(
+ tokenizer,
+ )
+ tokens.append(encoded.tokens)
+ labels.append(encoded.labels)
+ vq_parts.extend(encoded.vq_parts)
+ vq_mask_tokens.append(encoded.vq_mask_tokens)
+ vq_mask_labels.append(encoded.vq_mask_labels)
+ vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+ tokens = torch.cat(tokens, dim=0)
+ labels = torch.cat(labels, dim=0)
+ vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
+ vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
+ vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+ if add_shift:
+ tokens = tokens[:-1]
+ labels = labels[1:]
+ vq_mask_tokens = vq_mask_tokens[:-1]
+ vq_mask_labels = vq_mask_labels[1:]
+
+ for i in ignore_loss_token_ids:
+ assert i != -100 and i is not None
+ labels[labels == i] = -100
+
+ assert tokens.dtype in [
+ torch.int,
+ torch.long,
+ ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
+
+ return EncodedMessage(
+ tokens=tokens,
+ labels=labels,
+ vq_parts=vq_parts,
+ vq_mask_tokens=vq_mask_tokens,
+ vq_mask_labels=vq_mask_labels,
+ vq_require_losses=vq_require_losses,
+ )
+
+ def encode_for_inference(
+ self: "Conversation",
+ tokenizer: FishTokenizer,
+ num_codebooks: int,
+ ) -> EncodedMessage:
+ # self.visualize(tokenizer)
+
+ encoded = self.encode(tokenizer, add_shift=False)
+ tokens = encoded.tokens
+ values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+ values[0] = tokens
+
+ if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+ return values
+
+ vq_parts = encoded.vq_parts
+ vq_parts = [part.to(values.device) for part in vq_parts]
+ vq_parts = torch.cat(vq_parts, dim=1)
+ values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
+ values[1:, encoded.vq_mask_tokens] = vq_parts
+
+ return values
+
+ def visualize(
+ self: "Conversation",
+ tokenizer: FishTokenizer,
+ ignore_loss_tokens: list[str] = [],
+ ):
+ encoded = self.encode(
+ tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+ )
+
+ # Colors for alternating tokens
+ colors = {
+ "blue": "\033[94m", # Light blue
+ "cyan": "\033[96m", # Cyan
+ "green": "\033[92m", # Light green
+ "dark_green": "\033[32m", # Dark green
+ }
+ blue_idx = 0
+ green_idx = 0
+
+ def print_in_blue(x):
+ nonlocal blue_idx
+ color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+ print(f"{color}{x}\033[0m", end="")
+ blue_idx += 1
+
+ def print_in_green(x):
+ nonlocal green_idx
+ color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+ print(f"{color}{x}\033[0m", end="")
+ green_idx += 1
+
+ for tok, lab in zip(encoded.tokens, encoded.labels):
+ val = tokenizer.decode([tok])
+
+ if lab == -100:
+ print_in_green(val)
+ else:
+ print_in_blue(val)
+
+ print()
+
+ def append(self: "Conversation", message: Message):
+ self.messages.append(message)
+
+
+if __name__ == "__main__":
+ message0 = Message(
+ role="user",
+ parts=[
+ TextPart(text="Hello, how are you?"),
+ VQPart(codes=torch.zeros((4, 10))),
+ ],
+ cal_loss=False,
+ )
+
+ message1 = Message(
+ role="assistant",
+ parts=[TextPart(text="I'm fine, thank you.")],
+ cal_loss=True,
+ )
+ conversation = Conversation([message0, message1])
+ tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+ conversation.visualize(tokenizer)
+
+ encoded = conversation.encode(tokenizer)
+ print(encoded)
+ print(tokenizer.batch_decode(encoded.tokens))
diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa
--- /dev/null
+++ b/fish_speech/datasets/concat_repeat.py
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+ datasets: list[Dataset]
+ cumulative_sizes: list[int]
+ repeats: list[int]
+
+ @staticmethod
+ def cumsum(sequence, repeats):
+ r, s = [], 0
+ for dataset, repeat in zip(sequence, repeats):
+ l = len(dataset) * repeat
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+ super().__init__()
+
+ self.datasets = list(datasets)
+ self.repeats = repeats
+
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+ assert len(self.datasets) == len(
+ repeats
+ ), "datasets and repeats should have the same length"
+
+ for d in self.datasets:
+ assert not isinstance(
+ d, IterableDataset
+ ), "ConcatRepeatDataset does not support IterableDataset"
+
+ self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+ dataset = self.datasets[dataset_idx]
+
+ return dataset[sample_idx % len(dataset)]
diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto
new file mode 100644
index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379
--- /dev/null
+++ b/fish_speech/datasets/protos/text-data.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package text_data;
+
+message Semantics {
+ repeated uint32 values = 1;
+}
+
+message Sentence {
+ repeated string texts = 1;
+ repeated Semantics semantics = 3;
+}
+
+message TextData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence sentences = 4;
+}
+
+message SampledData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence samples = 3;
+}
diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals["_SEMANTICS"]._serialized_start = 30
+ _globals["_SEMANTICS"]._serialized_end = 57
+ _globals["_SENTENCE"]._serialized_start = 59
+ _globals["_SENTENCE"]._serialized_end = 125
+ _globals["_TEXTDATA"]._serialized_start = 127
+ _globals["_TEXTDATA"]._serialized_end = 207
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)
diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+ while True:
+ buf = f.read(4)
+ if len(buf) == 0:
+ break
+ size = struct.unpack("I", buf)[0]
+ buf = f.read(size)
+ text_data = TextData()
+ text_data.ParseFromString(buf)
+ yield text_data
+
+
+def write_pb_stream(f, text_data):
+ buf = text_data.SerializeToString()
+ f.write(struct.pack("I", len(buf)))
+ f.write(buf)
+
+
+def pack_pb_stream(text_data):
+ buf = text_data.SerializeToString()
+ return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+ while True:
+ head = f.read(4)
+ if len(head) == 0:
+ break
+ size = struct.unpack("I", head)[0]
+ buf = f.read(size)
+ yield head + buf
diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418
--- /dev/null
+++ b/fish_speech/datasets/semantic.py
@@ -0,0 +1,496 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def split_by_rank_worker(files):
+ # We need to know the total number of devices
+ # to split the data properly
+
+ total_devices = 1
+ if is_initialized():
+ total_devices = get_world_size()
+
+ worker_info = get_worker_info()
+ if worker_info is not None:
+ total_devices *= worker_info.num_workers
+
+ if len(files) < total_devices:
+ # Repeat the files N times to match the number of devices
+ files = files * (total_devices // len(files) + 1)
+
+ # DDP
+ if is_initialized():
+ files = files[get_rank() :: get_world_size()]
+
+ # Split by worker
+ if worker_info is not None:
+ files = files[worker_info.id :: worker_info.num_workers]
+
+ return files
+
+
+class AutoTextSemanticInstructionDataset(IterableDataset):
+ """
+ Auto Augment Dataset by Speaker
+
+ 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+ 2. Automatically normalize the text
+
+ For interactive mode, we use the following format (multiple sequences):
+ [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+
+ For non-interactive mode, we use the following format (one long sequence):
+ [INST] text [/INST] ...
+ """
+
+ def __init__(
+ self,
+ proto_files: list[str],
+ seed: int = 42,
+ interactive_prob: float = 0.5,
+ max_length: int = 1024,
+ tokenizer: AutoTokenizer = None,
+ use_speaker: bool | float = True,
+ causal: bool = True,
+ num_codebooks: Optional[int] = None,
+ skip_text_prob: float = 0.0,
+ ):
+ """
+ Args:
+ proto_files: proto buf files if using local data
+ seed: random seed
+ interactive_prob: probability to use interactive mode
+ max_length: max length of the text
+ tokenizer: tokenizer
+ use_speaker: include speaker information in the prompt
+ causal: use causal sampling when using local data, disable will lead to random sampling
+ num_codebooks: number of codebooks, if None, it will be automatically detected
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+ """
+
+ super().__init__()
+
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+ self.seed = seed
+ self.max_length = max_length
+ self.tokenizer = tokenizer
+ self.interactive_prob = interactive_prob
+ self.use_speaker = use_speaker
+ self.proto_files = proto_files
+ self.causal = causal
+ self.num_codebooks = num_codebooks
+ self.skip_text_prob = skip_text_prob
+
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+ self.groups = None
+
+ def init_mock_data_server(self):
+ if self.groups is not None:
+ return
+
+ # Expand the proto files
+ expanded_proto_files = []
+ for filename in self.proto_files:
+ for i in braceexpand(filename):
+ i = Path(i)
+ if i.is_file():
+ expanded_proto_files.append(i)
+ elif i.is_dir():
+ expanded_proto_files.extend(i.rglob("*.proto"))
+ expanded_proto_files.extend(i.rglob("*.protos"))
+ else:
+ raise ValueError(f"{i} is not a file or directory")
+
+ expanded_proto_files = sorted(expanded_proto_files)
+ Random(self.seed).shuffle(expanded_proto_files)
+
+ self.groups = []
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
+ log.info(
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+ )
+
+ count = 0
+ for filename in shard_proto_files:
+ with open(filename, "rb") as f:
+ for text_data in read_pb_stream(f):
+ self.groups.append(text_data)
+ count += 1
+
+ log.info(f"Read total {count} groups of data")
+
+ # Shuffle the lines
+ Random(self.seed).shuffle(self.groups)
+ self.group_weights = [len(i.sentences) for i in self.groups]
+
+ def __iter__(self):
+ while True:
+ yield self.augment()
+
+ def tokenize_sentence(self, sentence: str):
+ sentence = clean_text(sentence)
+ tokens = self.tokenizer.encode(
+ f"{sentence}",
+ max_length=10**6,
+ add_special_tokens=False,
+ truncation=False,
+ )
+ return sentence, len(tokens)
+
+ def sample_data(self):
+ if self.groups is None:
+ self.init_mock_data_server()
+
+ # Shuffle unique lines, estimate that each sample is at least 20 tokens
+ num_samples = self.max_length // 20
+
+ # choice group based on their number of samples
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
+
+ if self.causal:
+ # Sample in order
+ if num_samples >= len(group.sentences):
+ samples = group.sentences
+ else:
+ begin = random.randint(0, len(group.sentences) - num_samples)
+ samples = group.sentences[begin : begin + num_samples]
+ else:
+ samples = random.choices(
+ group.sentences, k=min(num_samples, len(group.sentences))
+ )
+
+ return SampledData(
+ source=group.source,
+ name=group.name,
+ samples=samples,
+ )
+
+ def augment(self):
+ final_text, final_semantic = [], []
+ response = self.sample_data()
+ if len(response.samples) == 0:
+ # Invalid group
+ return None
+
+ samples = list(response.samples)
+ idx = 0
+ use_interactive = random.random() < self.interactive_prob
+
+ if use_interactive is False:
+ # Random sample based on speaker using a truncated normal distribution
+ a = torch.tensor([0], dtype=torch.float32)
+ torch.nn.init.trunc_normal_(
+ a,
+ mean=self.max_length // 2,
+ std=self.max_length // 4,
+ a=10,
+ b=self.max_length,
+ )
+ remaining_tokens = a.long().item() - 4
+ else:
+ remaining_tokens = self.max_length
+
+ # Use speaker
+ if isinstance(self.use_speaker, float):
+ use_speaker = random.random() < self.use_speaker
+ else:
+ use_speaker = self.use_speaker
+
+ all_tokens, all_labels = [], []
+ while remaining_tokens > 0 and len(samples) > 0:
+ sentence = samples.pop(0)
+
+ text = random.choice(sentence.texts)
+ text, length = self.tokenize_sentence(text)
+ remaining_tokens -= length + len(sentence.semantics[0].values)
+
+ if use_interactive is False:
+ final_text.append(text)
+ final_semantic.append(sentence.semantics)
+ else:
+ # For interactive mode, we only apply speaker for the first sentence
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+ tokens, labels = self.pack_sentences(
+ sentences=[text],
+ semantics=[sentence.semantics],
+ speaker=response.name if use_speaker else None,
+ skip_text=random.random() < self.skip_text_prob,
+ )
+
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ idx += 1
+
+ if use_interactive is False:
+ tokens, labels = self.pack_sentences(
+ final_text,
+ semantics=final_semantic,
+ speaker=response.name if use_speaker else None,
+ )
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ tokens = torch.cat(all_tokens, dim=1)
+ labels = torch.cat(all_labels, dim=1)
+
+ # Verify that the length is correct
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+ data = {"tokens": tokens, "labels": labels}
+
+ return data
+
+ def pack_sentences(
+ self,
+ sentences: list[str],
+ semantics: list,
+ speaker: Optional[str] = None,
+ skip_text: bool = False,
+ ):
+ if speaker is None:
+ speaker = "assistant"
+
+ cated_sentences = " ".join(sentences)
+ if skip_text:
+ cated_sentences = "<|skip_text|>"
+
+ final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
+ final_text = final_text + f"<|im_start|>{speaker}\n"
+
+ encoded = self.tokenizer.encode(
+ final_text,
+ add_special_tokens=False,
+ truncation=False,
+ max_length=10**6,
+ )
+ semantic_length = sum([len(i[0].values) for i in semantics])
+ prompt_length = len(encoded)
+ num_codebooks = (
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+ )
+
+ # Pack the tokens and semantics (add and to semantic tokens)
+ tokens = (
+ encoded
+ + [self.semantic_token_id] * semantic_length
+ + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
+ )
+
+ # Codebook bos/padding: 0, eos: 1
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
+ for segment in semantics:
+ for book_idx, book in zip(range(num_codebooks), segment):
+ for j in book.values:
+ codes[book_idx].append(int(j) + 1)
+
+ for book in codes:
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
+
+ tokens = [tokens] + codes
+
+ tokens = torch.tensor(tokens, dtype=torch.long)
+ labels = tokens.clone()
+
+ if skip_text:
+ # If text is not provided, the sentence is used for condition only, all labels are -100
+ torch.fill_(labels, -100)
+ return tokens, labels
+
+ # Mask out the tokens for semantic, predict semantic tokens only
+ # Since we don't mask out the input tokens, the language modeling still works
+ labels[1:, :prompt_length] = -100
+
+ tokens = tokens[:, :-1]
+ labels = labels[:, 1:]
+
+ # Verify the padding is correct, and the last token is eos
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+ return tokens, labels
+
+
+@dataclass
+class TextDataCollator:
+ tokenizer: AutoTokenizer
+ max_length: int = 1024
+
+ def __call__(self, examples):
+ if "negative_tokens" in examples:
+ positive_examples = []
+ negative_examples = []
+
+ for i in examples:
+ positive_examples.append(
+ {
+ "tokens": i["tokens"],
+ "labels": i["labels"],
+ }
+ )
+ negative_examples.append(
+ {
+ "tokens": i["negative_tokens"],
+ "labels": i["negative_labels"],
+ }
+ )
+
+ examples = positive_examples + negative_examples
+
+ return self.batchify(examples)
+
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
+ tokens, attention_masks, labels = [], [], []
+
+ # Calculate the max length
+ max_tokens_length = 0
+ for example in examples:
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+ max_tokens_length = min(max_tokens_length, self.max_length)
+
+ for example in examples:
+ _tokens = example[tokens_key][:, :max_tokens_length]
+ _labels = example[labels_key][:, :max_tokens_length]
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
+ tokens_length = _tokens.size(1)
+ _attention_mask[:tokens_length] = False
+
+ assert tokens_length == _labels.size(
+ 1
+ ), f"{tokens_length} != {_labels.size(1)}"
+
+ if tokens_length < max_tokens_length:
+ _tokens = F.pad(
+ _tokens,
+ (0, max_tokens_length - tokens_length),
+ value=self.tokenizer.eos_token_id,
+ )
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
+ _labels = F.pad(
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
+ )
+
+ tokens.append(_tokens)
+ attention_masks.append(_attention_mask)
+ labels.append(_labels)
+
+ tokens = torch.stack(tokens, dim=0)
+ attention_masks = torch.stack(attention_masks, dim=0)
+ labels = torch.stack(labels, dim=0)
+
+ return {
+ "inputs": tokens,
+ "attention_masks": attention_masks,
+ "labels": labels,
+ }
+
+
+class InterleaveDataset(IterableDataset):
+ def __init__(
+ self,
+ datasets: list[IterableDataset],
+ probabilities: list[float],
+ seed: int = 42,
+ ):
+ super().__init__()
+
+ self.datasets = datasets
+ self.probabilities = probabilities
+ self.seed = seed
+
+ def __iter__(self):
+ rng = np.random.default_rng(self.seed)
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+ while True:
+ # Random choice one
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+ dataset_iterator = dataset_iterators[dataset_idx]
+
+ try:
+ yield next(dataset_iterator)
+ except StopIteration:
+ # Exhausted, create a new iterator
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+ yield next(dataset_iterators[dataset_idx])
+
+
+class SemanticDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ batch_size: int = 32,
+ tokenizer: AutoTokenizer = None,
+ max_length: int = 1024,
+ num_workers: int = 4,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+
+if __name__ == "__main__":
+ from tqdm import tqdm
+
+ ds = AutoTextSemanticInstructionDataset(
+ ["data/protos"],
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+ use_speaker=False,
+ interactive_prob=1.0,
+ skip_text_prob=0.5,
+ )
+
+ for i in ds:
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+ # i["labels"][0][i["labels"][0] == -100] = 0
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+ break
diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e
--- /dev/null
+++ b/fish_speech/datasets/vqgan.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
+class VQGANDataset(Dataset):
+ def __init__(
+ self,
+ filelist: str,
+ sample_rate: int = 32000,
+ hop_length: int = 640,
+ slice_frames: Optional[int] = None,
+ ):
+ super().__init__()
+
+ filelist = Path(filelist)
+ root = filelist.parent
+
+ self.files = [
+ root / line.strip()
+ for line in filelist.read_text(encoding="utf-8").splitlines()
+ if line.strip()
+ ]
+ self.sample_rate = sample_rate
+ self.hop_length = hop_length
+ self.slice_frames = slice_frames
+
+ def __len__(self):
+ return len(self.files)
+
+ def get_item(self, idx):
+ file = self.files[idx]
+
+ audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+
+ # Slice audio and features
+ if (
+ self.slice_frames is not None
+ and audio.shape[0] > self.slice_frames * self.hop_length
+ ):
+ start = np.random.randint(
+ 0, audio.shape[0] - self.slice_frames * self.hop_length
+ )
+ audio = audio[start : start + self.slice_frames * self.hop_length]
+
+ if len(audio) == 0:
+ return None
+
+ max_value = np.abs(audio).max()
+ if max_value > 1.0:
+ audio = audio / max_value
+
+ return {
+ "audio": torch.from_numpy(audio),
+ }
+
+ def __getitem__(self, idx):
+ try:
+ return self.get_item(idx)
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ logger.error(f"Error loading {self.files[idx]}: {e}")
+ return None
+
+
+@dataclass
+class VQGANCollator:
+ def __call__(self, batch):
+ batch = [x for x in batch if x is not None]
+
+ audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+ audio_maxlen = audio_lengths.max()
+
+ # Rounds up to nearest multiple of 2 (audio_lengths)
+ audios = []
+ for x in batch:
+ audios.append(
+ torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+ )
+
+ return {
+ "audios": torch.stack(audios),
+ "audio_lengths": audio_lengths,
+ }
+
+
+class VQGANDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: VQGANDataset,
+ val_dataset: VQGANDataset,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ val_batch_size: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.val_batch_size = val_batch_size or batch_size
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ shuffle=True,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.val_batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+
+if __name__ == "__main__":
+ dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+ dataloader = DataLoader(
+ dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+ )
+
+ for batch in dataloader:
+ print(batch["audios"].shape)
+ print(batch["features"].shape)
+ print(batch["audio_lengths"])
+ print(batch["feature_lengths"])
+ break
diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2
--- /dev/null
+++ b/fish_speech/i18n/README.md
@@ -0,0 +1,27 @@
+## i18n Folder Attribution
+
+The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
+
+### fish_speech/i18n/core.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
+
+**Initial commit:**
+add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
+
+**Initial author:**
+[@L4Ph](https://github.com/L4Ph)
+
+### fish_speech/i18n/scan.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
+
+**Initial commit:**
+File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
+
+**Initial author:**
+[@towzeur](https://github.com/towzeur)
+
+We appreciate the contributions of the RVC project and its authors.
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246
--- /dev/null
+++ b/fish_speech/i18n/__init__.py
@@ -0,0 +1,3 @@
+from .core import i18n
+
+__all__ = ["i18n"]
diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0627c071d991b544a3e0afbe4737cd98081b7e8c
Binary files /dev/null and b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/i18n/__pycache__/core.cpython-310.pyc b/fish_speech/i18n/__pycache__/core.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..085270e739a310ec9a8151e167b44260d332bd85
Binary files /dev/null and b/fish_speech/i18n/__pycache__/core.cpython-310.pyc differ
diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd
--- /dev/null
+++ b/fish_speech/i18n/core.py
@@ -0,0 +1,40 @@
+import json
+import locale
+from pathlib import Path
+
+I18N_FILE_PATH = Path(__file__).parent / "locale"
+DEFAULT_LANGUAGE = "en_US"
+
+
+def load_language_list(language):
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
+ language_list = json.load(f)
+
+ return language_list
+
+
+class I18nAuto:
+ def __init__(self):
+ i18n_file = Path(".locale")
+
+ if i18n_file.exists():
+ with open(i18n_file, "r", encoding="utf-8") as f:
+ language = f.read().strip()
+ else:
+ # getlocale can't identify the system's language ((None, None))
+ language = locale.getdefaultlocale()[0]
+
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
+ language = DEFAULT_LANGUAGE
+
+ self.language = language
+ self.language_map = load_language_list(language)
+
+ def __call__(self, key):
+ return self.language_map.get(key, key)
+
+ def __repr__(self):
+ return "Use Language: " + self.language
+
+
+i18n = I18nAuto()
diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json
new file mode 100644
index 0000000000000000000000000000000000000000..d36c774313628fe9d4ee60e816f404c09935e655
--- /dev/null
+++ b/fish_speech/i18n/locale/en_US.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
+ "Add to Processing Area": "Add to Processing Area",
+ "Added path successfully!": "Added path successfully!",
+ "Advanced Config": "Advanced Config",
+ "Base LLAMA Model": "Base LLAMA Model",
+ "Batch Inference": "Batch Inference",
+ "Batch Size": "Batch Size",
+ "Changing with the Model Path": "Changing with the Model Path",
+ "Chinese": "Chinese",
+ "Compile Model": "Compile Model",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
+ "Copy": "Copy",
+ "Data Preprocessing": "Data Preprocessing",
+ "Data Preprocessing Path": "Data Preprocessing Path",
+ "Data Source": "Data Source",
+ "Decoder Model Config": "Decoder Model Config",
+ "Decoder Model Path": "Decoder Model Path",
+ "Disabled": "Disabled",
+ "Enable Reference Audio": "Enable Reference Audio",
+ "English": "English",
+ "Error Message": "Error Message",
+ "File Preprocessing": "File Preprocessing",
+ "Generate": "Generate",
+ "Generated Audio": "Generated Audio",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
+ "Infer interface is closed": "Infer interface is closed",
+ "Inference Configuration": "Inference Configuration",
+ "Inference Server Configuration": "Inference Server Configuration",
+ "Inference Server Error": "Inference Server Error",
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
+ "Initial Learning Rate": "Initial Learning Rate",
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
+ "Input Text": "Input Text",
+ "Invalid path: {}": "Invalid path: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
+ "Japanese": "Japanese",
+ "LLAMA Configuration": "LLAMA Configuration",
+ "LLAMA Model Config": "LLAMA Model Config",
+ "LLAMA Model Path": "LLAMA Model Path",
+ "Labeling Device": "Labeling Device",
+ "LoRA Model to be merged": "LoRA Model to be merged",
+ "Maximum Audio Duration": "Maximum Audio Duration",
+ "Maximum Length per Sample": "Maximum Length per Sample",
+ "Maximum Training Steps": "Maximum Training Steps",
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
+ "Merge": "Merge",
+ "Merge LoRA": "Merge LoRA",
+ "Merge successfully": "Merge successfully",
+ "Minimum Audio Duration": "Minimum Audio Duration",
+ "Model Output Path": "Model Output Path",
+ "Model Size": "Model Size",
+ "Move": "Move",
+ "Move files successfully": "Move files successfully",
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
+ "No selected options": "No selected options",
+ "Number of Workers": "Number of Workers",
+ "Open Inference Server": "Open Inference Server",
+ "Open Labeler WebUI": "Open Labeler WebUI",
+ "Open Tensorboard": "Open Tensorboard",
+ "Opened labeler in browser": "Opened labeler in browser",
+ "Optional Label Language": "Optional Label Language",
+ "Optional online ver": "Optional online ver",
+ "Output Path": "Output Path",
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
+ "Precision": "Precision",
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
+ "Put your text here.": "Put your text here.",
+ "Reference Audio": "Reference Audio",
+ "Reference Text": "Reference Text",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
+ "Remove Selected Data": "Remove Selected Data",
+ "Removed path successfully!": "Removed path successfully!",
+ "Repetition Penalty": "Repetition Penalty",
+ "Save model every n steps": "Save model every n steps",
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
+ "Select VITS ckpt": "Select VITS ckpt",
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
+ "Select source file processing method": "Select source file processing method",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
+ "Selected: {}": "Selected: {}",
+ "Speaker": "Speaker",
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
+ "Start Training": "Start Training",
+ "Streaming Audio": "Streaming Audio",
+ "Streaming Generate": "Streaming Generate",
+ "Tensorboard Host": "Tensorboard Host",
+ "Tensorboard Log Path": "Tensorboard Log Path",
+ "Tensorboard Port": "Tensorboard Port",
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
+ "Training Configuration": "Training Configuration",
+ "Training Error": "Training Error",
+ "Training stopped": "Training stopped",
+ "Type name of the speaker": "Type name of the speaker",
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
+ "Use LoRA": "Use LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
+ "Use filelist": "Use filelist",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
+ "VITS Configuration": "VITS Configuration",
+ "VQGAN Configuration": "VQGAN Configuration",
+ "Validation Batch Size": "Validation Batch Size",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
+ "WebUI Host": "WebUI Host",
+ "WebUI Port": "WebUI Port",
+ "Whisper Model": "Whisper Model",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
+ "latest": "latest",
+ "new": "new",
+ "Realtime Transform Text": "Realtime Transform Text",
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
+ "Text Normalization": "Text Normalization",
+ "Select Example Audio": "Select Example Audio"
+}
diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json
new file mode 100644
index 0000000000000000000000000000000000000000..7a4757967dd0fe3807ba4d354e75ad7a88eb510e
--- /dev/null
+++ b/fish_speech/i18n/locale/es_ES.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
+ "Advanced Config": "Configuración Avanzada",
+ "Base LLAMA Model": "Modelo Base LLAMA",
+ "Batch Inference": "Inferencia por Lote",
+ "Batch Size": "Tamaño del Lote",
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
+ "Chinese": "Chino",
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Preprocesamiento de Datos",
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
+ "Data Source": "Fuente de Datos",
+ "Decoder Model Config": "Configuración del modelo decodificador",
+ "Decoder Model Path": "Ruta del modelo decodificador",
+ "Disabled": "Desactivado",
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
+ "English": "Inglés",
+ "Error Message": "Mensaje de Error",
+ "File Preprocessing": "Preprocesamiento de Archivos",
+ "Generate": "Generar",
+ "Generated Audio": "Audio Generado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
+ "Inference Configuration": "Configuración de Inferencia",
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
+ "Inference Server Error": "Error del Servidor de Inferencia",
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Ruta inválida: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
+ "Japanese": "Japonés",
+ "LLAMA Configuration": "Configuración de LLAMA",
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Etiquetado",
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
+ "Maximum Audio Duration": "Duración máxima de audio",
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
+ "Merge": "Fusionar",
+ "Merge LoRA": "Fusionar LoRA",
+ "Merge successfully": "Fusionado exitosamente",
+ "Minimum Audio Duration": "Duración mínima de audio",
+ "Model Output Path": "Ruta de Salida del Modelo",
+ "Model Size": "Tamaño del Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Archivos movidos exitosamente",
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
+ "No selected options": "No hay opciones seleccionadas",
+ "Number of Workers": "Número de Trabajadores",
+ "Open Inference Server": "Abrir Servidor de Inferencia",
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
+ "Optional online ver": "Ver en línea opcional",
+ "Output Path": "Ruta de Salida",
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
+ "Precision": "Precisión",
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
+ "Put your text here.": "Ponga su texto aquí.",
+ "Reference Audio": "Audio de Referencia",
+ "Reference Text": "Texto de Referencia",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
+ "Repetition Penalty": "Penalización por Repetición",
+ "Save model every n steps": "Guardar modelo cada n pasos",
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
+ "Selected: {}": "Seleccionado: {}",
+ "Speaker": "Hablante",
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
+ "Start Training": "Iniciar Entrenamiento",
+ "Streaming Audio": "transmisión de audio",
+ "Streaming Generate": "síntesis en flujo",
+ "Tensorboard Host": "Host de Tensorboard",
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
+ "Tensorboard Port": "Puerto de Tensorboard",
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
+ "Training Configuration": "Configuración de Entrenamiento",
+ "Training Error": "Error de Entrenamiento",
+ "Training stopped": "Entrenamiento detenido",
+ "Type name of the speaker": "Escriba el nombre del hablante",
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
+ "Use filelist": "Usar lista de archivos",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
+ "VITS Configuration": "Configuración de VITS",
+ "VQGAN Configuration": "Configuración de VQGAN",
+ "Validation Batch Size": "Tamaño del Lote de Validación",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
+ "WebUI Host": "Host de WebUI",
+ "WebUI Port": "Puerto de WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
+ "latest": "más reciente",
+ "new": "nuevo",
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
+ "Text Normalization": "Normalización de Texto",
+ "Select Example Audio": "Selecionar áudio de exemplo"
+}
diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json
new file mode 100644
index 0000000000000000000000000000000000000000..863b8b0b41da7e504ac0dcc4abf707f1f71a53fa
--- /dev/null
+++ b/fish_speech/i18n/locale/ja_JP.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
+ "Accumulate Gradient Batches": "勾配バッチの累積",
+ "Add to Processing Area": "処理エリアに追加",
+ "Added path successfully!": "パスの追加に成功しました!",
+ "Advanced Config": "詳細設定",
+ "Base LLAMA Model": "基本LLAMAモデル",
+ "Batch Inference": "バッチ推論",
+ "Batch Size": "バッチサイズ",
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
+ "Chinese": "中国語",
+ "Compile Model": "モデルのコンパイル",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
+ "Copy": "コピー",
+ "Data Preprocessing": "データ前処理",
+ "Data Preprocessing Path": "データ前処理パス",
+ "Data Source": "データソース",
+ "Decoder Model Config": "デコーダーモデルの構成",
+ "Decoder Model Path": "デコーダーモデルのパス",
+ "Disabled": "無効",
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
+ "English": "英語",
+ "Error Message": "エラーメッセージ",
+ "File Preprocessing": "文書前处理",
+ "Generate": "生成",
+ "Generated Audio": "生成されたオーディオ",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
+ "Infer interface is closed": "推論インターフェースが閉じられています",
+ "Inference Configuration": "推論設定",
+ "Inference Server Configuration": "推論サーバー設定",
+ "Inference Server Error": "推論サーバーエラー",
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
+ "Initial Learning Rate": "初期学習率",
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
+ "Input Text": "入力テキスト",
+ "Invalid path: {}": "無効なパス: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
+ "Japanese": "日本語",
+ "LLAMA Configuration": "LLAMA設定",
+ "LLAMA Model Config": "LLAMAモデル設定",
+ "LLAMA Model Path": "LLAMAモデルパス",
+ "Labeling Device": "ラベリングデバイス",
+ "LoRA Model to be merged": "マージするLoRAモデル",
+ "Maximum Audio Duration": "最大オーディオの長さ",
+ "Maximum Length per Sample": "サンプルあたりの最大長",
+ "Maximum Training Steps": "最大トレーニングステップ数",
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
+ "Merge": "マージ",
+ "Merge LoRA": "LoRAのマージ",
+ "Merge successfully": "マージに成功しました",
+ "Minimum Audio Duration": "最小オーディオの長さ",
+ "Model Output Path": "モデル出力パス",
+ "Model Size": "モデルサイズ",
+ "Move": "移動",
+ "Move files successfully": "ファイルの移動に成功しました",
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
+ "No selected options": "選択されたオプションはありません",
+ "Number of Workers": "ワーカー数",
+ "Open Inference Server": "推論サーバーを開く",
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
+ "Open Tensorboard": "Tensorboardを開く",
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
+ "Optional Label Language": "オプションのラベル言語",
+ "Optional online ver": "オプションのオンラインバージョン",
+ "Output Path": "出力パス",
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
+ "Put your text here.": "ここにテキストを入力してください。",
+ "Reference Audio": "リファレンスオーディオ",
+ "Reference Text": "リファレンステキスト",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
+ "Remove Selected Data": "選択したデータを削除",
+ "Removed path successfully!": "パスの削除に成功しました!",
+ "Repetition Penalty": "反復ペナルティ",
+ "Save model every n steps": "nステップごとにモデルを保存",
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
+ "Select VITS ckpt": "VITS チェックポイントを選択",
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
+ "Select source file processing method": "ソースファイルの処理方法を選択",
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
+ "Selected: {}": "選択済み: {}",
+ "Speaker": "話者",
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
+ "Start Training": "トレーニング開始",
+ "Streaming Audio": "ストリーミングオーディオ",
+ "Streaming Generate": "ストリーミング合成",
+ "Tensorboard Host": "Tensorboardホスト",
+ "Tensorboard Log Path": "Tensorboardログパス",
+ "Tensorboard Port": "Tensorboardポート",
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
+ "Training Configuration": "トレーニング設定",
+ "Training Error": "トレーニングエラー",
+ "Training stopped": "トレーニングが停止しました",
+ "Type name of the speaker": "話者の名前を入力",
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
+ "Use LoRA": "LoRAを使用",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
+ "Use filelist": "ファイルリストを使用",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
+ "VITS Configuration": "VITS の構成",
+ "VQGAN Configuration": "VQGAN の構成",
+ "Validation Batch Size": "検証バッチサイズ",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
+ "WebUI Host": "WebUIホスト",
+ "WebUI Port": "WebUIポート",
+ "Whisper Model": "Whisperモデル",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
+ "latest": "最新",
+ "new": "新規",
+ "Realtime Transform Text": "リアルタイム変換テキスト",
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
+ "Text Normalization": "テキスト正規化",
+ "Select Example Audio": "サンプル音声を選択"
+}
diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json
new file mode 100644
index 0000000000000000000000000000000000000000..180263874b476059870035d4c2b74ce5fa553a8a
--- /dev/null
+++ b/fish_speech/i18n/locale/ko_KR.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
+ "Add to Processing Area": "처리 영역에 추가",
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+ "Advanced Config": "고급 설정",
+ "Base LLAMA Model": "기본 LLAMA 모델",
+ "Batch Inference": "배치 추론",
+ "Batch Size": "배치 크기",
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
+ "Chinese": "중국어",
+ "Compile Model": "모델 컴파일",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+ "Copy": "복사",
+ "Data Preprocessing": "데이터 전처리",
+ "Data Preprocessing Path": "데이터 전처리 경로",
+ "Data Source": "데이터 소스",
+ "Decoder Model Config": "디코더 모델 설정",
+ "Decoder Model Path": "디코더 모델 경로",
+ "Disabled": "비활성화 됨",
+ "Enable Reference Audio": "참고 음성 활성화",
+ "English": "영어",
+ "Error Message": "오류 메시지",
+ "File Preprocessing": "파일 전처리",
+ "Generate": "생성",
+ "Generated Audio": "생성된 오디오",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+ "Inference Configuration": "추론 설정",
+ "Inference Server Configuration": "추론 서버 설정",
+ "Inference Server Error": "추론 서버 오류",
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+ "Initial Learning Rate": "초기 학습률",
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+ "Input Text": "입력 텍스트",
+ "Invalid path: {}": "유효하지 않은 경로: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+ "Japanese": "일본어",
+ "LLAMA Configuration": "LLAMA 설정",
+ "LLAMA Model Config": "LLAMA 모델 설정",
+ "LLAMA Model Path": "LLAMA 모델 경로",
+ "Labeling Device": "라벨링 장치",
+ "LoRA Model to be merged": "병합할 LoRA 모델",
+ "Maximum Audio Duration": "최대 오디오 길이",
+ "Maximum Length per Sample": "샘플당 최대 길이",
+ "Maximum Training Steps": "최대 학습 단계",
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+ "Merge": "병합",
+ "Merge LoRA": "LoRA 병합",
+ "Merge successfully": "성공적으로 병합 되었습니다.",
+ "Minimum Audio Duration": "최소 오디오 길이",
+ "Model Output Path": "모델 출력 경로",
+ "Model Size": "모델 크기",
+ "Move": "이동",
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+ "No selected options": "옵션이 선택되지 않았습니다.",
+ "Number of Workers": "작업자 수",
+ "Open Inference Server": "추론 서버 열기",
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
+ "Open Tensorboard": "Tensorboard 열기",
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+ "Optional Label Language": "선택적 라벨 언어",
+ "Optional online ver": "온라인 버전 선택",
+ "Output Path": "출력 경로",
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+ "Precision": "정밀도",
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
+ "Reference Audio": "참고 오디오",
+ "Reference Text": "참고 텍스트",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+ "Remove Selected Data": "선택한 데이터 제거",
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+ "Repetition Penalty": "반복 패널티",
+ "Save model every n steps": "n 단계마다 모델 저장",
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
+ "Select VITS ckpt": "VITS ckpt 선택",
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
+ "Select source file processing method": "소스 파일 처리 방법 선택",
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+ "Selected: {}": "선택됨: {}",
+ "Speaker": "화자",
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+ "Start Training": "학습 시작",
+ "Streaming Audio": "스트리밍 오디오",
+ "Streaming Generate": "스트리밍 생성",
+ "Tensorboard Host": "Tensorboard 호스트",
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
+ "Tensorboard Port": "Tensorboard 포트",
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+ "Training Configuration": "학습 설정",
+ "Training Error": "학습 오류",
+ "Training stopped": "학습이 중지되었습니다.",
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+ "Use LoRA": "LoRA 사용",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+ "Use filelist": "파일 목록 사용",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+ "VITS Configuration": "VITS 설정",
+ "VQGAN Configuration": "VQGAN 설정",
+ "Validation Batch Size": "검증 배치 크기",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+ "WebUI Host": "WebUI 호스트",
+ "WebUI Port": "WebUI 포트",
+ "Whisper Model": "Whisper 모델",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+ "latest": "최신",
+ "new": "새로운",
+ "Realtime Transform Text": "실시간 텍스트 변환",
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+ "Text Normalization": "텍스트 정규화",
+ "Select Example Audio": "예시 오디오 선택"
+}
diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json
new file mode 100644
index 0000000000000000000000000000000000000000..385f20272e19053ab9b6cf6463a84c8ece768c68
--- /dev/null
+++ b/fish_speech/i18n/locale/pt_BR.json
@@ -0,0 +1,133 @@
+{
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
+ "Add to Processing Area": "Adicionar à Área de Processamento",
+ "Added path successfully!": "Caminho adicionado com sucesso!",
+ "Advanced Config": "Configuração Avançada",
+ "Base LLAMA Model": "Modelo LLAMA Base",
+ "Batch Inference": "Inferência em Lote",
+ "Batch Size": "Tamanho do Lote",
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
+
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Pré-processamento de Dados",
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
+ "Data Source": "Fonte de Dados",
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
+ "Disabled": "Desativado",
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
+ "English": "Inglês",
+ "Japanese": "Japonês",
+ "Chinese": "Chinês",
+ "Portuguese": "Português",
+ "Spanish": "Espanhol",
+ "Error Message": "Mensagem de Erro",
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
+ "File Preprocessing": "Pré-processamento de Arquivos",
+ "Generate": "Gerar",
+ "Generated Audio": "Áudio Gerado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
+ "Infer interface is closed": "A interface de inferência foi fechada",
+ "Inference Configuration": "Configuração de Inferência",
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
+ "Inference Server Error": "Erro do Servidor de Inferência",
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
+ "Initial Prompt": "Prompt Inicial",
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Caminho inválido: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
+ "LLAMA Configuration": "Configuração do LLAMA",
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Rotulagem",
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
+ "Merge": "Mesclar",
+ "Merge LoRA": "Mesclar LoRA",
+ "Merge successfully": "Mesclado com sucesso",
+ "Model Output Path": "Caminho de Saída do Modelo",
+ "Model Quantization": "Quantização do Modelo",
+ "Model Size": "Tamanho do Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Arquivos movidos com sucesso",
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
+ "No selected options": "Nenhuma opção selecionada",
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
+ "Number of Workers": "Número de Processos",
+ "Open Inference Server": "Abrir Servidor de Inferência",
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
+ "Optional online ver": "Versão online (opcional)",
+ "Output Path": "Caminho de Saída",
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
+ "Post-quantification Precision": "Precisão Pós-quantização",
+ "Precision": "Precisão",
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
+ "Put your text here.": "Insira seu texto aqui.",
+ "Quantify": "Quantizar",
+ "Quantify successfully": "Quantizado com sucesso",
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
+ "Reference Audio": "Áudio de Referência",
+ "Reference Text": "Texto de Referência",
+ "warning": "Aviso",
+ "Pre-processing begins...": "O pré-processamento começou!",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Remover Dados Selecionados",
+ "Removed path successfully!": "Caminho removido com sucesso!",
+ "Repetition Penalty": "Penalidade de Repetição",
+ "Save model every n steps": "Salvar modelo a cada n etapas",
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
+ "Selected: {}": "Selecionado: {}",
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
+ "Start Training": "Iniciar Treinamento",
+ "Streaming Audio": "Áudio em Streaming",
+ "Streaming Generate": "Geração em Streaming",
+ "Tensorboard Host": "Host do Tensorboard",
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
+ "Tensorboard Port": "Porta do Tensorboard",
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
+ "Text Normalization": "Normalização de Texto",
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
+ "Training Configuration": "Configuração de Treinamento",
+ "Training Error": "Erro de Treinamento",
+ "Training stopped": "Treinamento interrompido!",
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
+ "Use filelist": "Usar lista de arquivos",
+ "VQGAN Configuration": "Configuração do VQGAN",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
+ "WebUI Host": "Host da WebUI",
+ "WebUI Port": "Porta da WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
+ "auto": "automático",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
+ "latest": "mais recente",
+ "new": "novo",
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
+ "Yes": "Sim",
+ "No": "Não",
+ "version:": "versão:",
+ "author:": "autor:"
+}
diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json
new file mode 100644
index 0000000000000000000000000000000000000000..9068ef0b9a41b9941b37644c6a4c96ec6a5d836e
--- /dev/null
+++ b/fish_speech/i18n/locale/zh_CN.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
+ "Accumulate Gradient Batches": "梯度累积批次",
+ "Add to Processing Area": "加入处理区",
+ "Added path successfully!": "添加路径成功!",
+ "Advanced Config": "高级参数",
+ "Base LLAMA Model": "基础 LLAMA 模型",
+ "Batch Inference": "批量推理",
+ "Batch Size": "批次大小",
+ "Changing with the Model Path": "随模型路径变化",
+ "Chinese": "中文",
+ "Compile Model": "编译模型",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
+ "Copy": "复制",
+ "Data Preprocessing": "数据预处理",
+ "Data Preprocessing Path": "数据预处理路径",
+ "Data Source": "数据源",
+ "Decoder Model Config": "解码器模型配置",
+ "Decoder Model Path": "解码器模型路径",
+ "Disabled": "禁用",
+ "Enable Reference Audio": "启用参考音频",
+ "English": "英文",
+ "Error Message": "错误信息",
+ "File Preprocessing": "文件预处理",
+ "Generate": "生成",
+ "Generated Audio": "音频",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
+ "Infer interface is closed": "推理界面已关闭",
+ "Inference Configuration": "推理配置",
+ "Inference Server Configuration": "推理服务器配置",
+ "Inference Server Error": "推理服务器错误",
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
+ "Initial Learning Rate": "初始学习率",
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
+ "Input Text": "输入文本",
+ "Invalid path: {}": "无效路径: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
+ "Japanese": "日文",
+ "LLAMA Configuration": "LLAMA 配置",
+ "LLAMA Model Config": "LLAMA 模型配置",
+ "LLAMA Model Path": "LLAMA 模型路径",
+ "Labeling Device": "标注加速设备",
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
+ "Maximum Audio Duration": "最大音频时长",
+ "Maximum Length per Sample": "每个样本的最大长度",
+ "Maximum Training Steps": "最大训练步数",
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
+ "Merge": "合并",
+ "Merge LoRA": "合并 LoRA",
+ "Merge successfully": "合并成功",
+ "Minimum Audio Duration": "最小音频时长",
+ "Model Output Path": "模型输出路径",
+ "Model Size": "模型规模",
+ "Move": "移动",
+ "Move files successfully": "移动文件成功",
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
+ "No selected options": "没有选择的选项",
+ "Number of Workers": "数据加载进程数",
+ "Open Inference Server": "打开推理服务器",
+ "Open Labeler WebUI": "打开标注工具",
+ "Open Tensorboard": "打开 Tensorboard",
+ "Opened labeler in browser": "在浏览器中打开标注工具",
+ "Optional Label Language": "[可选] 标注语言",
+ "Optional online ver": "[可选] 使用在线版",
+ "Output Path": "输出路径",
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
+ "Put your text here.": "在此处输入文本.",
+ "Reference Audio": "参考音频",
+ "Reference Text": "参考文本",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
+ "Remove Selected Data": "移除选中数据",
+ "Removed path successfully!": "移除路径成功!",
+ "Repetition Penalty": "重复惩罚",
+ "Save model every n steps": "每 n 步保存模型",
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
+ "Select VITS ckpt": "选择 VITS 检查点",
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
+ "Select source file processing method": "选择源文件处理方法",
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
+ "Selected: {}": "已选择: {}",
+ "Speaker": "说话人",
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
+ "Start Training": "开始训练",
+ "Streaming Audio": "流式音频",
+ "Streaming Generate": "流式合成",
+ "Tensorboard Host": "Tensorboard 监听地址",
+ "Tensorboard Log Path": "Tensorboard 日志路径",
+ "Tensorboard Port": "Tensorboard 端口",
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
+ "Training Configuration": "训练配置",
+ "Training Error": "训练错误",
+ "Training stopped": "训练已停止",
+ "Type name of the speaker": "输入说话人的名称",
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
+ "Use LoRA": "使用 LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
+ "Use filelist": "使用文件列表",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
+ "VITS Configuration": "VITS 配置",
+ "VQGAN Configuration": "VQGAN 配置",
+ "Validation Batch Size": "验证批次大小",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
+ "WebUI Host": "WebUI 监听地址",
+ "WebUI Port": "WebUI 端口",
+ "Whisper Model": "Whisper 模型",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
+ "latest": "最近的检查点",
+ "new": "创建新的检查点",
+ "Realtime Transform Text": "实时规范化文本",
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
+ "Text Normalization": "文本规范化",
+ "Select Example Audio": "选择参考音频"
+}
diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1
--- /dev/null
+++ b/fish_speech/i18n/scan.py
@@ -0,0 +1,122 @@
+import ast
+import glob
+import json
+from collections import OrderedDict
+from pathlib import Path
+
+from loguru import logger
+
+from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
+
+
+def extract_i18n_strings(node):
+ i18n_strings = []
+
+ if (
+ isinstance(node, ast.Call)
+ and isinstance(node.func, ast.Name)
+ and node.func.id == "i18n"
+ ):
+ for arg in node.args:
+ if isinstance(arg, ast.Str):
+ i18n_strings.append(arg.s)
+
+ for child_node in ast.iter_child_nodes(node):
+ i18n_strings.extend(extract_i18n_strings(child_node))
+
+ return i18n_strings
+
+
+# scan the directory for all .py files (recursively)
+# for each file, parse the code into an AST
+# for each AST, extract the i18n strings
+
+strings = []
+folders = ["fish_speech", "tools"]
+# for filename in glob.iglob("**/*.py", recursive=True):
+for folder in folders:
+ for f in Path(folder).rglob("*.py"):
+ code = f.read_text(encoding="utf-8")
+ if "i18n(" in code:
+ tree = ast.parse(code)
+ i18n_strings = extract_i18n_strings(tree)
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
+ strings.extend(i18n_strings)
+
+code_keys = set(strings)
+logger.info(f"Total unique: {len(code_keys)}")
+
+
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+standard_keys = set(standard_data.keys())
+
+# Define the standard file name
+unused_keys = standard_keys - code_keys
+logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
+for unused_key in unused_keys:
+ logger.info(f"\t{unused_key}")
+
+missing_keys = code_keys - standard_keys
+logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
+for missing_key in missing_keys:
+ logger.info(f"\t{missing_key}")
+
+code_keys_dict = OrderedDict()
+for s in strings:
+ code_keys_dict[s] = s
+
+# write back
+with open(standard_file, "w", encoding="utf-8") as f:
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+logger.info(f"Updated {standard_file}")
+
+
+# Define the standard file name
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+
+# Find all JSON files in the directory
+dir_path = I18N_FILE_PATH
+languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
+
+# Load the standard file
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+
+# Loop through each language file
+for lang_file in languages:
+ # Load the language file
+ with open(lang_file, "r", encoding="utf-8") as f:
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
+
+ # Find the difference between the language file and the standard file
+ diff = set(standard_data.keys()) - set(lang_data.keys())
+
+ miss = set(lang_data.keys()) - set(standard_data.keys())
+
+ # Add any missing keys to the language file
+ for key in diff:
+ lang_data[key] = "#!" + key
+ logger.info(f"Added missing key: {key} to {lang_file}")
+
+ # Del any extra keys to the language file
+ for key in miss:
+ del lang_data[key]
+ logger.info(f"Del extra key: {key} from {lang_file}")
+
+ # Sort the keys of the language file to match the order of the standard file
+ lang_data = OrderedDict(
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
+ )
+
+ # Save the updated language file
+ with open(lang_file, "w", encoding="utf-8") as f:
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+ logger.info(f"Updated {lang_file}")
+
+logger.info("Done")
diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d8b2061c891167ac04d56037461230d0bfcf642
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc40d2aa932be0e9841ed3a7b7605d21f9e365ab
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2dc7d6fd91af83be0aa28000779af1bc0d342ac7
Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc differ
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0
--- /dev/null
+++ b/fish_speech/models/text2semantic/lit_module.py
@@ -0,0 +1,202 @@
+from typing import Any, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+class TextToSemantic(L.LightningModule):
+ def __init__(
+ self,
+ model: NaiveTransformer,
+ optimizer: Any,
+ lr_scheduler: Any,
+ ):
+ super().__init__()
+
+ self.model = model
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+
+ def forward(self, x):
+ return self.model(x)
+
+ def on_save_checkpoint(self, checkpoint):
+ # Save only LoRA parameters
+ state_dict = checkpoint["state_dict"]
+ use_lora = any("lora" in name for name in state_dict.keys())
+ if not use_lora:
+ return
+
+ for name in list(state_dict.keys()):
+ if "lora" not in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ # Get weight decay parameters
+ weight_decay_parameters, other_parameters = [], []
+ for name, param in self.named_parameters():
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+ other_parameters.append(param)
+ else:
+ weight_decay_parameters.append(param)
+
+ optimizer = self.optimizer_builder(
+ [
+ {"params": weight_decay_parameters},
+ {"params": other_parameters, "weight_decay": 0.0},
+ ]
+ )
+
+ # Print the parameters and their weight decay
+ for i in optimizer.param_groups:
+ log.info(
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+ )
+
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ },
+ }
+
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
+ def get_batch_logps(
+ self,
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ assert logits.shape[:-1] == labels.shape
+
+ labels = labels.clone()
+ loss_mask = labels != -100
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == -100] = 0
+
+ per_token_logps = torch.gather(
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
+
+ def _step(self, batch, batch_idx, stage: str):
+ is_train = stage == "train"
+
+ if is_train:
+ # Key part to make lora work
+ # Otherwise the parameters are merged, which lead to incorrect gradients
+ self.model.train()
+
+ # Do positive and negative samples in the same batch to speed up training
+ labels = batch["labels"]
+ outputs = self.model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+ # Generate labels
+ base_loss = F.cross_entropy(
+ token_logits.view(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ )
+
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.view(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ )
+
+ loss = base_loss + semantic_loss
+
+ self.log(
+ f"{stage}/loss",
+ loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/base_loss",
+ base_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/semantic_loss",
+ semantic_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ # Top-5 accuracy
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
+ self.log(
+ f"{stage}/top_5_accuracy",
+ accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ return loss
+
+ def get_accuracy(self, logits, labels):
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+ if mask.sum() == 0:
+ return torch.tensor(0.0, device=logits.device)
+
+ _, indices = logits.topk(5, dim=-1)
+ correct = indices.eq(labels.unsqueeze(-1))
+ correct[~mask] = 0
+ correct = correct.sum()
+ accuracy = correct / mask.sum()
+
+ return accuracy
+
+ def training_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "train")
+
+ def validation_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "val")
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..1811f091ce282bbd55615747f2c21c3005133616
--- /dev/null
+++ b/fish_speech/models/text2semantic/llama.py
@@ -0,0 +1,887 @@
+import dataclasses
+import json
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class BaseModelArgs:
+ model_type: str = "base"
+
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ max_seq_len: int = 2048
+ dropout: float = 0.0
+ tie_word_embeddings: bool = True
+ attention_qkv_bias: bool = False
+
+ # Codebook configs
+ codebook_size: int = 160
+ num_codebooks: int = 4
+
+ # Gradient checkpointing
+ use_gradient_checkpointing: bool = True
+
+ # Initialize the model
+ initializer_range: float = 0.02
+
+ # Dummy vars
+ is_reward_model: bool = False
+ share_codebook_embeddings: bool = True
+ scale_codebook_embeddings: bool = False
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+ @staticmethod
+ def from_pretrained(path: str):
+ path = Path(path)
+
+ if path.is_dir():
+ path = path / "config.json"
+
+ with open(path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ match data["model_type"]:
+ case "naive":
+ cls = NaiveModelArgs
+ case "dual_ar":
+ cls = DualARModelArgs
+ case _:
+ raise ValueError(f"Unknown model type: {data['model_type']}")
+
+ return cls(**data)
+
+ def save(self, path: str):
+ with open(path, "w") as f:
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+ model_type: str = "naive"
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+ model_type: str = "dual_ar"
+ n_fast_layer: int = 4
+ fast_dim: int | None = None
+ fast_n_head: int | None = None
+ fast_n_local_heads: int | None = None
+ fast_head_dim: int | None = None
+ fast_intermediate_size: int | None = None
+ fast_attention_qkv_bias: bool | None = None
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ self.fast_dim = self.fast_dim or self.dim
+ self.fast_n_head = self.fast_n_head or self.n_head
+ self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+ self.fast_head_dim = self.fast_head_dim or self.head_dim
+ self.fast_intermediate_size = (
+ self.fast_intermediate_size or self.intermediate_size
+ )
+ self.fast_attention_qkv_bias = (
+ self.fast_attention_qkv_bias
+ if self.fast_attention_qkv_bias is not None
+ else self.attention_qkv_bias
+ )
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+ token_logits: Tensor
+ codebook_logits: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+ logits: Tensor
+ hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(
+ self,
+ config: BaseModelArgs,
+ tokenizer: FishTokenizer | AutoTokenizer,
+ init_weights: bool = True,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.tokenizer = tokenizer
+ self.semantic_token_ids = [
+ tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
+ ]
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size,
+ config.dim,
+ )
+ self.codebook_embeddings = nn.Embedding(
+ config.codebook_size * config.num_codebooks,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ if self.config.tie_word_embeddings is False:
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.dim // config.n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ if init_weights:
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def embed(self, x: Tensor) -> Tensor:
+ vocab_embeds = [self.embeddings(x[:, 0])]
+ for i in range(self.config.num_codebooks):
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+ semantic_token_ids_tensor = torch.tensor(
+ self.semantic_token_ids, device=x.device
+ )
+ emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
+
+ x = torch.stack(vocab_embeds, dim=3)
+ x = x.sum(dim=3)
+
+ return x
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ seq_len = inp.size(2)
+
+ # Here we want to merge the embeddings of the codebooks
+ x = self.embed(inp)
+
+ freqs_cis = self.freqs_cis[:seq_len]
+
+ # Not that the causal mask here follows the definition of scaled_dot_product_attention
+ # That is, FALSE means masked out
+ # To maintain consistency, key_padding_mask use TRUE to mask out
+ mask = None
+ if key_padding_mask is not None:
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self,
+ inp: Tensor,
+ input_pos: Optional[Tensor] = None,
+ vq_masks: Optional[Tensor] = None, # this is not used in fact
+ return_all: bool = False,
+ ) -> BaseTransformerForwardResult:
+ # This is used for generation, optimized for torch compile
+ # assert (
+ # self.max_seq_len != -1 and self.max_batch_size != -1
+ # ), "Please call setup_caches before forward_generate"
+
+ embeds = []
+ for i in range(self.config.num_codebooks):
+ if self.config.share_codebook_embeddings:
+ _tokens = inp[:, i + 1] + i * self.config.codebook_size
+ else:
+ _tokens = inp[:, i + 1]
+
+ emb = self.codebook_embeddings(_tokens)
+ embeds.append(emb)
+
+ vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+ # if self.config.use_codebook_mlp:
+ # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
+ # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
+
+ vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
+ inp[:, 0] <= self.tokenizer.semantic_end_id
+ )
+
+ vq_embeds_sum[~vq_masks] = 0
+ x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+ if input_pos is None:
+ input_pos = torch.arange(inp.shape[-1], device=x.device)
+ max_seq_len = inp.shape[-1]
+ else:
+ max_seq_len = self.max_seq_len
+
+ mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+ # If prefill, we only calculate the logits of last token
+ if x.size(1) > 1 and not return_all:
+ x = x[:, -1:]
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.is_reward_model:
+ token_logits = self.score_output(slow_out)
+ elif self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @staticmethod
+ def from_pretrained(
+ path: str,
+ load_weights: bool = False,
+ max_length: int | None = None,
+ lora_config: LoraConfig | None = None,
+ rope_base: int | None = None,
+ is_agent: bool = False,
+ ) -> "BaseTransformer":
+ config = BaseModelArgs.from_pretrained(str(path))
+ if max_length is not None:
+ config.max_seq_len = max_length
+ log.info(f"Override max_seq_len to {max_length}")
+
+ if rope_base is not None:
+ config.rope_base = rope_base
+ log.info(f"Override rope_base to {rope_base}")
+
+ match config.model_type:
+ case "naive":
+ model_cls = NaiveTransformer
+ case "dual_ar":
+ model_cls = DualARTransformer
+ case _:
+ raise ValueError(f"Unknown model type: {config.model_type}")
+
+ if is_agent:
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
+ else:
+ tokenizer_path = str(path) + "/tokenizer.tiktoken"
+ tokenizer = FishTokenizer(tokenizer_path)
+
+ log.info(f"Loading model from {path}, config: {config}")
+ model = model_cls(config, tokenizer=tokenizer)
+
+ if lora_config is not None:
+ setup_lora(model, lora_config)
+ log.info(f"LoRA setup: {lora_config}")
+
+ if load_weights is False:
+ log.info("Randomly initialized model")
+ else:
+
+ if "int8" in str(Path(path)):
+ logger.info("Using int8 weight-only quantization!")
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ model = simple_quantizer.convert_for_runtime()
+
+ if "int4" in str(Path(path)):
+ logger.info("Using int4 quantization!")
+ path_comps = path.name.split("-")
+ assert path_comps[-2].startswith("g")
+ groupsize = int(path_comps[-2][1:])
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ model = simple_quantizer.convert_for_runtime()
+
+ weights = torch.load(
+ Path(path) / "model.pth",
+ map_location="cpu",
+ mmap=True,
+ weights_only=True,
+ )
+
+ if "state_dict" in weights:
+ logger.warning(
+ "Using a TextToSemantic LightningModule checkpoint, "
+ "please make sure it is a full model, not a LoRA model."
+ )
+ weights = weights["state_dict"]
+
+ if next(iter(weights.keys())).startswith("model."):
+ logger.info(
+ f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
+ )
+ new_weights = OrderedDict()
+ for k, v in weights.items():
+ new_weights[k.replace("model.", "")] = v
+ weights = new_weights
+
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
+ for k, v in model.named_parameters():
+ if k not in weights:
+ logger.warning(f"No weight for {k}")
+ elif v.shape != weights[k].shape:
+ logger.warning(
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
+ )
+
+ err = model.load_state_dict(weights, strict=False, assign=True)
+ log.info(f"Loaded weights with error: {err}")
+
+ return model
+
+ def save_pretrained(self, path: str, drop_lora: bool = False):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ self.config.save(path / "config.json")
+ state_dict = self.state_dict()
+
+ if drop_lora:
+ for key in list(state_dict.keys()):
+ if "lora" not in key:
+ continue
+
+ state_dict.pop(key)
+ log.info(f"Drop LoRA parameter: {key}")
+
+ torch.save(state_dict, path / "model.pth")
+ self.tokenizer.save_pretrained(path)
+
+
+class NaiveTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.codebook_output = nn.Linear(
+ config.dim,
+ config.codebook_size * config.num_codebooks,
+ bias=False,
+ )
+
+ self.apply(self._init_weights)
+
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+ token_logits = result.logits
+ x = result.hidden_states
+
+ # Codebook
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
+ codebook_logits = rearrange(
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ result = super().forward(
+ inp=inp,
+ key_padding_mask=key_padding_mask,
+ )
+ return self.decode(result)
+
+ def forward_generate(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ result = super().forward_generate(x, input_pos)
+ return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ # Project to fast dim if needed
+ if config.fast_dim is not None and config.fast_dim != config.dim:
+ self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+ else:
+ self.fast_project_in = nn.Identity()
+
+ # Fast transformer
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
+
+ # The equivalent bs is so large that sdpa doesn't work
+ override_config = dataclasses.replace(
+ config,
+ dim=config.fast_dim,
+ n_head=config.fast_n_head,
+ n_local_heads=config.fast_n_local_heads,
+ head_dim=config.fast_head_dim,
+ intermediate_size=config.fast_intermediate_size,
+ attention_qkv_bias=config.fast_attention_qkv_bias,
+ )
+
+ self.fast_layers = nn.ModuleList(
+ TransformerBlock(override_config, use_sdpa=False)
+ for _ in range(config.n_fast_layer)
+ )
+ self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
+ self.fast_output = nn.Linear(
+ config.fast_dim,
+ config.codebook_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "fast_freqs_cis",
+ precompute_freqs_cis(
+ config.num_codebooks,
+ config.fast_dim // config.fast_n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+ head_dim = self.config.fast_dim // self.config.fast_n_head
+
+ # Fast transformer
+ # The max seq len here is the number of codebooks
+ for b in self.fast_layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ self.config.num_codebooks,
+ self.config.fast_n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ parent_result = super().forward(inp, key_padding_mask)
+ token_logits = parent_result.logits
+ x = parent_result.hidden_states
+ x = self.fast_project_in(x)
+
+ # Fast transformer
+ fast_seq_len = self.config.num_codebooks
+ fast_mask = self.causal_mask[
+ None, None, :fast_seq_len, :fast_seq_len
+ ] # (B, N, Q, K)
+
+ # Drop the last token and rotate left
+ codebooks = inp[:, 1:-1, 1:]
+ codebooks = F.pad(codebooks, (0, 1), value=0)
+ codebook_embeddings = self.fast_embeddings(codebooks)
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
+ b, s = x.size(0), x.size(2)
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
+
+ # Remove padded part
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
+ codebook_mask = (codebooks == 0).all(dim=-1)
+
+ if torch.all(codebook_mask):
+ # If all codebooks are padded, we keep first 8 to make sure the model runs
+ codebook_mask[:8] = False
+
+ x_bs, x_len = x.size(0), x.size(1)
+ x = x[~codebook_mask]
+
+ for layer in self.fast_layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(
+ layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+ )
+ else:
+ x = layer(x, self.fast_freqs_cis, fast_mask)
+
+ # unflatten the batch and num_codebooks
+ fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ # Re-pad the codebook_logits
+ buffer = torch.zeros(
+ x_bs,
+ x_len,
+ codebook_logits.size(-1),
+ device=codebook_logits.device,
+ dtype=codebook_logits.dtype,
+ )
+ buffer[~codebook_mask] = codebook_logits
+ codebook_logits = buffer
+
+ assert codebook_logits.shape[1] == self.config.num_codebooks
+ codebook_logits = rearrange(
+ codebook_logits,
+ "(b s) n d -> b s n d",
+ b=b,
+ s=s,
+ n=self.config.num_codebooks,
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward_generate_fast(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> Tensor:
+ # Fast transformer
+ x = x.view(1, 1, -1)
+
+ fast_mask = self.causal_mask[
+ None, None, input_pos, : self.config.num_codebooks
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.fast_freqs_cis[input_pos]
+
+ for layer in self.fast_layers:
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
+
+ # unflatten the batch and num_codebooks
+ fast_out = self.fast_norm(x) # only take the last token
+ codebook_logits = self.fast_output(fast_out)
+
+ return codebook_logits
+
+ def forward_generate(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ vq_masks: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ x = super().forward_generate(x, input_pos, vq_masks)
+ x.hidden_states = self.fast_project_in(x.hidden_states)
+ return x
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+ super().__init__()
+ self.attention = Attention(config, use_sdpa=use_sdpa)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
+ )
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.dropout = config.dropout
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.use_sdpa = use_sdpa
+ self._register_load_state_dict_pre_hook(self.load_hook)
+
+ def load_hook(self, state_dict, prefix, *args):
+ if prefix + "wq.weight" in state_dict:
+ wq = state_dict.pop(prefix + "wq.weight")
+ wk = state_dict.pop(prefix + "wk.weight")
+ wv = state_dict.pop(prefix + "wv.weight")
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.use_sdpa:
+ if mask is None:
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=True,
+ # No third party attn_mask here to use flash_attention
+ )
+ else:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+ else:
+ y = self.eq_scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ return self.wo(y)
+
+ def eq_scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ ) -> torch.Tensor:
+ # This is a standard scaled dot product attention
+ # It's low efficient, but it doesn't raise cuda error
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938
--- /dev/null
+++ b/fish_speech/models/text2semantic/lora.py
@@ -0,0 +1,92 @@
+from dataclasses import dataclass
+
+import loralib as lora
+
+
+@dataclass
+class LoraConfig:
+ r: int
+ lora_alpha: float
+ lora_dropout: float = 0.0
+
+
+def setup_lora(model, lora_config):
+ # Replace the embedding layer with a LoRA layer
+ model.embeddings = lora.Embedding(
+ num_embeddings=model.embeddings.num_embeddings,
+ embedding_dim=model.embeddings.embedding_dim,
+ padding_idx=model.embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ model.codebook_embeddings = lora.Embedding(
+ num_embeddings=model.codebook_embeddings.num_embeddings,
+ embedding_dim=model.codebook_embeddings.embedding_dim,
+ padding_idx=model.codebook_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Replace output layer with a LoRA layer
+ linears = [(model, "output")]
+
+ # Replace all linear layers with LoRA layers
+ for layer in model.layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ if hasattr(model, "fast_layers"):
+ model.fast_embeddings = lora.Embedding(
+ num_embeddings=model.fast_embeddings.num_embeddings,
+ embedding_dim=model.fast_embeddings.embedding_dim,
+ padding_idx=model.fast_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Dual-AR model
+ linears.append((model, "fast_output"))
+
+ for layer in model.fast_layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ for module, layer in linears:
+ updated_linear = lora.Linear(
+ in_features=getattr(module, layer).in_features,
+ out_features=getattr(module, layer).out_features,
+ bias=getattr(module, layer).bias,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ lora_dropout=lora_config.lora_dropout,
+ )
+ setattr(module, layer, updated_linear)
+
+ # Mark only the LoRA layers as trainable
+ lora.mark_only_lora_as_trainable(model, bias="none")
+
+
+def get_merged_state_dict(model):
+ # This line will merge the state dict of the model and the LoRA parameters
+ model.eval()
+
+ # Then we need to remove the LoRA parameters from the state dict
+ state_dict = model.state_dict()
+ for name in list(state_dict.keys()):
+ if "lora" in name:
+ state_dict.pop(name)
+
+ return state_dict
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32d03c53a1dbe28c4e9906d8f0f38f97de8fcb33
Binary files /dev/null and b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b13028d2185ba8125bd3ae0e4dadba988e5588f4
Binary files /dev/null and b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc differ
diff --git a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dfafdedbad457d5796aad896c12b6d1d76401db
Binary files /dev/null and b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc differ
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
new file mode 100644
index 0000000000000000000000000000000000000000..91fc9118cc26f4d99171e7db3ee871071a7a296a
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/firefly.py
@@ -0,0 +1,596 @@
+import math
+from functools import partial
+from math import prod
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.utils.checkpoint import checkpoint
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv1D") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+
+def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left:end]
+
+
+def get_extra_padding_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad1d(
+ x: torch.Tensor,
+ paddings: tuple[int, int],
+ mode: str = "zeros",
+ value: float = 0.0,
+):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right
+ before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == "reflect":
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+class FishConvNet(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
+ ):
+ super(FishConvNet, self).__init__()
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ )
+ self.stride = stride
+ self.kernel_size = (kernel_size - 1) * dilation + 1
+ self.dilation = dilation
+
+ def forward(self, x):
+ pad = self.kernel_size - self.stride
+ extra_padding = get_extra_padding_for_conv1d(
+ x, self.kernel_size, self.stride, pad
+ )
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
+ return self.conv(x).contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_parametrizations(self, name="weight"):
+ self.conv = remove_parametrizations(self.conv, name)
+ return self
+
+
+class FishTransConvNet(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
+ super(FishTransConvNet, self).__init__()
+ self.conv = nn.ConvTranspose1d(
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
+ )
+ self.stride = stride
+ self.kernel_size = kernel_size
+
+ def forward(self, x):
+ x = self.conv(x)
+ pad = self.kernel_size - self.stride
+ padding_right = math.ceil(pad)
+ padding_left = pad - padding_right
+ x = unpad1d(x, (padding_left, padding_right))
+ return x.contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_parametrizations(self, name="weight"):
+ self.conv = remove_parametrizations(self.conv, name)
+ return self
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
+ ).weight_norm(),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
+ ).weight_norm(),
+ FishConvNet(
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
+ ).weight_norm(),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.silu(x)
+ xt = c1(xt)
+ xt = F.silu(xt)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_parametrizations(self):
+ for conv in self.convs1:
+ conv.remove_parametrizations()
+ for conv in self.convs2:
+ conv.remove_parametrizations()
+
+
+class ParallelBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ kernel_sizes: tuple[int] = (3, 7, 11),
+ dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ ):
+ super().__init__()
+
+ assert len(kernel_sizes) == len(dilation_sizes)
+
+ self.blocks = nn.ModuleList()
+ for k, d in zip(kernel_sizes, dilation_sizes):
+ self.blocks.append(ResBlock1(channels, k, d))
+
+ def forward(self, x):
+ return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
+
+ def remove_parametrizations(self):
+ for block in self.blocks:
+ block.remove_parametrizations()
+
+
+class HiFiGANGenerator(nn.Module):
+ def __init__(
+ self,
+ *,
+ hop_length: int = 512,
+ upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+ upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+ resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+ resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ num_mels: int = 128,
+ upsample_initial_channel: int = 512,
+ pre_conv_kernel_size: int = 7,
+ post_conv_kernel_size: int = 7,
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
+ ):
+ super().__init__()
+
+ assert (
+ prod(upsample_rates) == hop_length
+ ), f"hop_length must be {prod(upsample_rates)}"
+
+ self.conv_pre = FishConvNet(
+ num_mels,
+ upsample_initial_channel,
+ pre_conv_kernel_size,
+ stride=1,
+ ).weight_norm()
+
+ self.num_upsamples = len(upsample_rates)
+ self.num_kernels = len(resblock_kernel_sizes)
+
+ self.noise_convs = nn.ModuleList()
+ self.ups = nn.ModuleList()
+
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ FishTransConvNet(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ stride=u,
+ ).weight_norm()
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ self.resblocks.append(
+ ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+ )
+
+ self.activation_post = post_activation()
+ self.conv_post = FishConvNet(
+ ch, 1, post_conv_kernel_size, stride=1
+ ).weight_norm()
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ x = F.silu(x, inplace=True)
+ x = self.ups[i](x)
+
+ if self.training and self.checkpointing:
+ x = checkpoint(
+ self.resblocks[i],
+ x,
+ use_reentrant=False,
+ )
+ else:
+ x = self.resblocks[i](x)
+
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_parametrizations(self):
+ for up in self.ups:
+ up.remove_parametrizations()
+ for block in self.resblocks:
+ block.remove_parametrizations()
+ self.conv_pre.remove_parametrizations()
+ self.conv_post.remove_parametrizations()
+
+
+# DropPath copied from timm library
+def drop_path(
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """ # noqa: E501
+
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """ # noqa: E501
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None] * x + self.bias[:, None]
+ return x
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
+class ConvNeXtBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
+ dilation (int): Dilation for depthwise conv. Default: 1.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim: int,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ mlp_ratio: float = 4.0,
+ kernel_size: int = 7,
+ dilation: int = 1,
+ ):
+ super().__init__()
+
+ self.dwconv = FishConvNet(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ # padding=int(dilation * (kernel_size - 1) / 2),
+ groups=dim,
+ ) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, int(mlp_ratio * dim)
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x, apply_residual: bool = True):
+ input = x
+
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
+ x = self.drop_path(x)
+
+ if apply_residual:
+ x = input + x
+
+ return x
+
+
+class ConvNeXtEncoder(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 3,
+ depths: list[int] = [3, 3, 9, 3],
+ dims: list[int] = [96, 192, 384, 768],
+ drop_path_rate: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ kernel_size: int = 7,
+ ):
+ super().__init__()
+ assert len(depths) == len(dims)
+
+ self.downsample_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ FishConvNet(
+ input_channels,
+ dims[0],
+ kernel_size=7,
+ # padding=3,
+ # padding_mode="replicate",
+ # padding_mode="zeros",
+ ),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+ )
+ self.downsample_layers.append(stem)
+
+ for i in range(len(depths) - 1):
+ mid_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+ )
+ self.downsample_layers.append(mid_layer)
+
+ self.stages = nn.ModuleList()
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+ cur = 0
+ for i in range(len(depths)):
+ stage = nn.Sequential(
+ *[
+ ConvNeXtBlock(
+ dim=dims[i],
+ drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value,
+ kernel_size=kernel_size,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ for i in range(len(self.downsample_layers)):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+
+ return self.norm(x)
+
+
+class FireflyArchitecture(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ head: nn.Module,
+ quantizer: nn.Module,
+ spec_transform: nn.Module,
+ ):
+ super().__init__()
+
+ self.backbone = backbone
+ self.head = head
+ self.quantizer = quantizer
+ self.spec_transform = spec_transform
+ self.downsample_factor = math.prod(self.quantizer.downsample_factor)
+
+ def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
+ if self.spec_transform is not None:
+ x = self.spec_transform(x)
+
+ x = self.backbone(x)
+ if mask is not None:
+ x = x * mask
+
+ if self.quantizer is not None:
+ vq_result = self.quantizer(x)
+ x = vq_result.z
+
+ if mask is not None:
+ x = x * mask
+
+ x = self.head(x, template=template)
+
+ if x.ndim == 2:
+ x = x[:, None, :]
+
+ if self.vq is not None:
+ return x, vq_result
+
+ return x
+
+ def encode(self, audios, audio_lengths):
+ audios = audios.float()
+
+ mels = self.spec_transform(audios)
+ mel_lengths = audio_lengths // self.spec_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ mels = mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.backbone(mels) * mel_masks_float_conv
+ feature_lengths = mel_lengths // self.downsample_factor
+
+ return self.quantizer.encode(encoded_features), feature_lengths
+
+ def decode(self, indices, feature_lengths) -> torch.Tensor:
+ mel_masks = sequence_mask(
+ feature_lengths * self.downsample_factor,
+ indices.shape[2] * self.downsample_factor,
+ )
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ audio_lengths = (
+ feature_lengths * self.downsample_factor * self.spec_transform.hop_length
+ )
+
+ audio_masks = sequence_mask(
+ audio_lengths,
+ indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
+ )
+ audio_masks_float_conv = audio_masks[:, None, :].float()
+
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
+ x = self.head(z) * audio_masks_float_conv
+
+ return x, audio_lengths
+
+ def remove_parametrizations(self):
+ if hasattr(self.backbone, "remove_parametrizations"):
+ self.backbone.remove_parametrizations()
+
+ if hasattr(self.head, "remove_parametrizations"):
+ self.head.remove_parametrizations()
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..954553bbfe0b7b18d348db6c03bf04fc0c916c4f
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/fsq.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
+from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
+
+
+@dataclass
+class FSQResult:
+ z: torch.Tensor
+ codes: torch.Tensor
+ latents: torch.Tensor
+
+
+class DownsampleFiniteScalarQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ n_groups: int = 1,
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
+ downsample_factor: tuple[int] = (2, 2),
+ downsample_dims: tuple[int] | None = None,
+ ):
+ super().__init__()
+
+ if downsample_dims is None:
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+ all_dims = (input_dim,) + tuple(downsample_dims)
+
+ self.residual_fsq = GroupedResidualFSQ(
+ dim=all_dims[-1],
+ levels=levels,
+ num_quantizers=n_codebooks,
+ groups=n_groups,
+ )
+
+ self.downsample_factor = downsample_factor
+ self.downsample_dims = downsample_dims
+
+ self.downsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ FishConvNet(
+ all_dims[idx],
+ all_dims[idx + 1],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
+ )
+ for idx, factor in enumerate(downsample_factor)
+ ]
+ )
+
+ self.upsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ FishTransConvNet(
+ all_dims[idx + 1],
+ all_dims[idx],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx]),
+ )
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
+ ]
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, z) -> FSQResult:
+ original_shape = z.shape
+ z = self.downsample(z)
+ quantized, indices = self.residual_fsq(z.mT)
+ result = FSQResult(
+ z=quantized.mT,
+ codes=indices.mT,
+ latents=z,
+ )
+ result.z = self.upsample(result.z)
+
+ # Pad or crop z to match original shape
+ diff = original_shape[-1] - result.z.shape[-1]
+ left = diff // 2
+ right = diff - left
+
+ if diff > 0:
+ result.z = F.pad(result.z, (left, right))
+ elif diff < 0:
+ result.z = result.z[..., -left:right]
+
+ return result
+
+ def encode(self, z):
+ z = self.downsample(z)
+ _, indices = self.residual_fsq(z.mT)
+ indices = rearrange(indices, "g b l r -> b (g r) l")
+ return indices
+
+ def decode(self, indices: torch.Tensor):
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+ z_q = self.residual_fsq.get_output_from_indices(indices)
+ z_q = self.upsample(z_q.mT)
+ return z_q
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac
--- /dev/null
+++ b/fish_speech/models/vqgan/utils.py
@@ -0,0 +1,94 @@
+import matplotlib
+import torch
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def plot_mel(data, titles=None):
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+ if titles is None:
+ titles = [None for i in range(len(data))]
+
+ plt.tight_layout()
+
+ for i in range(len(data)):
+ mel = data[i]
+
+ if isinstance(mel, torch.Tensor):
+ mel = mel.float().detach().cpu().numpy()
+
+ axes[i][0].imshow(mel, origin="lower")
+ axes[i][0].set_aspect(2.5, adjustable="box")
+ axes[i][0].set_ylim(0, mel.shape[0])
+ axes[i][0].set_title(titles[i], fontsize="medium")
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+ axes[i][0].set_anchor("W")
+
+ return fig
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
+ n_channels_int = n_channels[0]
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+
+ return acts
+
+
+def avg_with_mask(x, mask):
+ assert mask.dtype == torch.float, "Mask should be float"
+
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(1)
+
+ if mask.shape[1] == 1:
+ mask = mask.expand_as(x)
+
+ return (x * mask).sum() / mask.sum()
diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29
--- /dev/null
+++ b/fish_speech/scheduler.py
@@ -0,0 +1,40 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ final_lr_ratio: float = 0.0,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ progress = float(current_step - num_warmup_steps) / float(
+ max(1, num_training_steps - num_warmup_steps)
+ )
+
+ return max(
+ final_lr_ratio,
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+ )
+
+
+def get_constant_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int | None = None,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ return 1.0
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce
--- /dev/null
+++ b/fish_speech/text/__init__.py
@@ -0,0 +1,4 @@
+from .clean import clean_text
+from .spliter import split_text
+
+__all__ = ["clean_text", "split_text"]
diff --git a/fish_speech/text/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..761a1d129e0b89a063a2e64510d64a5b75cee0c4
Binary files /dev/null and b/fish_speech/text/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/text/__pycache__/clean.cpython-310.pyc b/fish_speech/text/__pycache__/clean.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6cab7ec74d9132bcc449ec0085b399207646fa2
Binary files /dev/null and b/fish_speech/text/__pycache__/clean.cpython-310.pyc differ
diff --git a/fish_speech/text/__pycache__/spliter.cpython-310.pyc b/fish_speech/text/__pycache__/spliter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c5319ee21d6b7d82e35d081027a654ff7937ee3
Binary files /dev/null and b/fish_speech/text/__pycache__/spliter.cpython-310.pyc differ
diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/.gitignore
@@ -0,0 +1,114 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# JetBrains PyCharm
+.idea
+
+# Customize
+references
+url.txt
+
+# Git
+.git
diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/README.md
@@ -0,0 +1,36 @@
+# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
+
+# Chn Text Norm
+
+this is a repository for chinese text normalization (no longer maintained).
+
+## Quick Start ##
+
+### Git Clone Repo ###
+
+git clone this repo to the root directory of your project which need to use it.
+
+ cd /path/to/proj
+ git clone https://github.com/Joee1995/chn-text-norm.git
+
+after that, your doc tree should be:
+```
+proj # root of your project
+|--- chn_text_norm # this chn-text-norm tool
+ |--- text.py
+ |--- ...
+|--- text_normalize.py # your text normalization code
+|--- ...
+```
+
+### How to Use ? ###
+
+ # text_normalize.py
+ from chn_text_norm.text import *
+
+ raw_text = 'your raw text'
+ text = Text(raw_text=raw_text).normalize()
+
+### How to add quantums ###
+
+打开test.py,然后你就知道怎么做了。
diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_class.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""基本类
+中文字符类
+中文数字/数位类
+中文数字类
+中文数位类
+中文数字系统类
+中文数学符号类
+*中文其他符号类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
+
+
+class ChineseChar(object):
+ """
+ 中文字符
+ 每个字符对应简体和繁体,
+ e.g. 简体 = '负', 繁体 = '負'
+ 转换时可转换为简体或繁体
+ """
+
+ def __init__(self, simplified, traditional):
+ self.simplified = simplified
+ self.traditional = traditional
+ self.__repr__ = self.__str__
+
+ def __str__(self):
+ return self.simplified or self.traditional or None
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+ """
+ 中文数字/数位字符
+ 每个字符除繁简体外还有一个额外的大写字符
+ e.g. '陆' 和 '陸'
+ """
+
+ def __init__(self, power, simplified, traditional, big_s, big_t):
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
+ self.power = power
+ self.big_s = big_s
+ self.big_t = big_t
+
+ def __str__(self):
+ return "10^{}".format(self.power)
+
+ @classmethod
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+ if small_unit:
+ return ChineseNumberUnit(
+ power=index + 1,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[1],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[0]:
+ return ChineseNumberUnit(
+ power=index + 8,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[1]:
+ return ChineseNumberUnit(
+ power=(index + 2) * 4,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[2]:
+ return ChineseNumberUnit(
+ power=pow(2, index + 3),
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ else:
+ raise ValueError(
+ "Counting type should be in {0} ({1} provided).".format(
+ NUMBERING_TYPES, numbering_type
+ )
+ )
+
+
+class ChineseNumberDigit(ChineseChar):
+ """
+ 中文数字字符
+ """
+
+ def __init__(
+ self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
+ ):
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
+ self.value = value
+ self.big_s = big_s
+ self.big_t = big_t
+ self.alt_s = alt_s
+ self.alt_t = alt_t
+
+ def __str__(self):
+ return str(self.value)
+
+ @classmethod
+ def create(cls, i, v):
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+ """
+ 中文数位字符
+ """
+
+ def __init__(self, simplified, traditional, symbol, expression=None):
+ super(ChineseMath, self).__init__(simplified, traditional)
+ self.symbol = symbol
+ self.expression = expression
+ self.big_s = simplified
+ self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+ """
+ 中文数字系统
+ """
+
+ pass
+
+
+class MathSymbol(object):
+ """
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
+ positive = ['正', '正']
+ negative = ['负', '負']
+ point = ['点', '點']
+ """
+
+ def __init__(self, positive, negative, point):
+ self.positive = positive
+ self.negative = negative
+ self.point = point
+
+ def __iter__(self):
+ for v in self.__dict__.values():
+ yield v
+
+
+# class OtherSymbol(object):
+# """
+# 其他符号
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_constant.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+"""基本常量
+中文数字/数位/符号字符常量
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+CHINESE_DIGIS = "零一二三四五六七八九"
+BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
+BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
+
+ZERO_ALT = "〇"
+ONE_ALT = "幺"
+TWO_ALTS = ["两", "兩"]
+
+POSITIVE = ["正", "正"]
+NEGATIVE = ["负", "負"]
+POINT = ["点", "點"]
+# PLUS = [u'加', u'加']
+# SIL = [u'杠', u'槓']
+
+# 中文数字系统类型
+NUMBERING_TYPES = ["low", "mid", "high"]
diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_util.py
@@ -0,0 +1,342 @@
+# -*- coding: utf-8 -*-
+"""基本方法
+创建中文数字系统 方法
+中文字符串 <=> 数字串 方法
+数字串 <=> 中文字符串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_class import *
+from fish_speech.text.chn_text_norm.basic_constant import *
+
+
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+ """
+ 根据数字系统类型返回创建相应的数字系统,默认为 mid
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+ 返回对应的数字系统
+ """
+
+ # chinese number units of '亿' and larger
+ all_larger_units = zip(
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
+ )
+ larger_units = [
+ CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
+ ]
+ # chinese number units of '十, 百, 千, 万'
+ all_smaller_units = zip(
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
+ )
+ smaller_units = [
+ CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
+ ]
+ # digis
+ chinese_digis = zip(
+ CHINESE_DIGIS,
+ CHINESE_DIGIS,
+ BIG_CHINESE_DIGIS_SIMPLIFIED,
+ BIG_CHINESE_DIGIS_TRADITIONAL,
+ )
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+ # symbols
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+ system = NumberSystem()
+ system.units = smaller_units + larger_units
+ system.digits = digits
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+ # system.symbols = OtherSymbol(sil_cn)
+ return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+
+ def get_symbol(char, system):
+ for u in system.units:
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+ return u
+ for d in system.digits:
+ if char in [
+ d.traditional,
+ d.simplified,
+ d.big_s,
+ d.big_t,
+ d.alt_s,
+ d.alt_t,
+ ]:
+ return d
+ for m in system.math:
+ if char in [m.traditional, m.simplified]:
+ return m
+
+ def string2symbols(chinese_string, system):
+ int_string, dec_string = chinese_string, ""
+ for p in [system.math.point.simplified, system.math.point.traditional]:
+ if p in chinese_string:
+ int_string, dec_string = chinese_string.split(p)
+ break
+ return [get_symbol(c, system) for c in int_string], [
+ get_symbol(c, system) for c in dec_string
+ ]
+
+ def correct_symbols(integer_symbols, system):
+ """
+ 一百八 to 一百八十
+ 一亿一千三百万 to 一亿 一千万 三百万
+ """
+
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
+ if integer_symbols[0].power == 1:
+ integer_symbols = [system.digits[1]] + integer_symbols
+
+ if len(integer_symbols) > 1:
+ if isinstance(integer_symbols[-1], CND) and isinstance(
+ integer_symbols[-2], CNU
+ ):
+ integer_symbols.append(
+ CNU(integer_symbols[-2].power - 1, None, None, None, None)
+ )
+
+ result = []
+ unit_count = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ result.append(s)
+ unit_count = 0
+ elif isinstance(s, CNU):
+ current_unit = CNU(s.power, None, None, None, None)
+ unit_count += 1
+
+ if unit_count == 1:
+ result.append(current_unit)
+ elif unit_count > 1:
+ for i in range(len(result)):
+ if (
+ isinstance(result[-i - 1], CNU)
+ and result[-i - 1].power < current_unit.power
+ ):
+ result[-i - 1] = CNU(
+ result[-i - 1].power + current_unit.power,
+ None,
+ None,
+ None,
+ None,
+ )
+ return result
+
+ def compute_value(integer_symbols):
+ """
+ Compute the value.
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
+ """
+ value = [0]
+ last_power = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ value[-1] = s.value
+ elif isinstance(s, CNU):
+ value[-1] *= pow(10, s.power)
+ if s.power > last_power:
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
+ last_power = s.power
+ value.append(0)
+ return sum(value)
+
+ system = create_system(numbering_type)
+ int_part, dec_part = string2symbols(chinese_string, system)
+ int_part = correct_symbols(int_part, system)
+ int_str = str(compute_value(int_part))
+ dec_str = "".join([str(d.value) for d in dec_part])
+ if dec_part:
+ return "{0}.{1}".format(int_str, dec_str)
+ else:
+ return int_str
+
+
+def num2chn(
+ number_string,
+ numbering_type=NUMBERING_TYPES[1],
+ big=False,
+ traditional=False,
+ alt_zero=False,
+ alt_one=False,
+ alt_two=True,
+ use_zeros=True,
+ use_units=True,
+):
+
+ def get_value(value_string, use_zeros=True):
+
+ striped_string = value_string.lstrip("0")
+
+ # record nothing if all zeros
+ if not striped_string:
+ return []
+
+ # record one digits
+ elif len(striped_string) == 1:
+ if use_zeros and len(value_string) != len(striped_string):
+ return [system.digits[0], system.digits[int(striped_string)]]
+ else:
+ return [system.digits[int(striped_string)]]
+
+ # recursively record multiple digits
+ else:
+ result_unit = next(
+ u for u in reversed(system.units) if u.power < len(striped_string)
+ )
+ result_string = value_string[: -result_unit.power]
+ return (
+ get_value(result_string)
+ + [result_unit]
+ + get_value(striped_string[-result_unit.power :])
+ )
+
+ system = create_system(numbering_type)
+
+ int_dec = number_string.split(".")
+ if len(int_dec) == 1:
+ int_string = int_dec[0]
+ dec_string = ""
+ elif len(int_dec) == 2:
+ int_string = int_dec[0]
+ dec_string = int_dec[1]
+ else:
+ raise ValueError(
+ "invalid input num string with more than one dot: {}".format(number_string)
+ )
+
+ if use_units and len(int_string) > 1:
+ result_symbols = get_value(int_string)
+ else:
+ result_symbols = [system.digits[int(c)] for c in int_string]
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
+ if dec_string:
+ result_symbols += [system.math.point] + dec_symbols
+
+ if alt_two:
+ liang = CND(
+ 2,
+ system.digits[2].alt_s,
+ system.digits[2].alt_t,
+ system.digits[2].big_s,
+ system.digits[2].big_t,
+ )
+ for i, v in enumerate(result_symbols):
+ if isinstance(v, CND) and v.value == 2:
+ next_symbol = (
+ result_symbols[i + 1] if i < len(result_symbols) - 1 else None
+ )
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
+ if isinstance(next_symbol, CNU) and isinstance(
+ previous_symbol, (CNU, type(None))
+ ):
+ if next_symbol.power != 1 and (
+ (previous_symbol is None) or (previous_symbol.power != 1)
+ ):
+ result_symbols[i] = liang
+
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
+ if big:
+ attr_name = "big_"
+ if traditional:
+ attr_name += "t"
+ else:
+ attr_name += "s"
+ else:
+ if traditional:
+ attr_name = "traditional"
+ else:
+ attr_name = "simplified"
+
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
+
+ # if not use_zeros:
+ # result = result.strip(getattr(system.digits[0], attr_name))
+
+ if alt_zero:
+ result = result.replace(
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s
+ )
+
+ if alt_one:
+ result = result.replace(
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s
+ )
+
+ for i, p in enumerate(POINT):
+ if result.startswith(p):
+ return CHINESE_DIGIS[0] + result
+
+ # ^10, 11, .., 19
+ if (
+ len(result) >= 2
+ and result[1]
+ in [
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
+ ]
+ and result[0]
+ in [
+ CHINESE_DIGIS[1],
+ BIG_CHINESE_DIGIS_SIMPLIFIED[1],
+ BIG_CHINESE_DIGIS_TRADITIONAL[1],
+ ]
+ ):
+ result = result[1:]
+
+ return result
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ all_chinese_number_string = (
+ CHINESE_DIGIS
+ + BIG_CHINESE_DIGIS_SIMPLIFIED
+ + BIG_CHINESE_DIGIS_TRADITIONAL
+ + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
+ + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
+ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
+ + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
+ + ZERO_ALT
+ + ONE_ALT
+ + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
+ )
+
+ print("num:", chn2num("一万零四百零三点八零五"))
+ print("num:", chn2num("一亿六点三"))
+ print("num:", chn2num("一亿零六点三"))
+ print("num:", chn2num("两千零一亿六点三"))
+ # print('num:', chn2num('一零零八六'))
+ print("txt:", num2chn("10260.03", alt_zero=True))
+ print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
+ print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
+ print(
+ "txt:",
+ num2chn(
+ "059523810880",
+ alt_one=True,
+ alt_two=False,
+ use_lzeros=True,
+ use_rzeros=True,
+ use_units=False,
+ ),
+ )
+
+ print(all_chinese_number_string)
diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/cardinal.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""CARDINAL类 (包含小数DECIMAL类)
+纯数 <=> 中文字符串 方法
+中文字符串 <=> 纯数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Cardinal:
+ """
+ CARDINAL类
+ """
+
+ def __init__(self, cardinal=None, chntext=None):
+ self.cardinal = cardinal
+ self.chntext = chntext
+
+ def chntext2cardinal(self):
+ return chn2num(self.chntext)
+
+ def cardinal2chntext(self):
+ return num2chn(self.cardinal)
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Cardinal(cardinal="21357.230").cardinal2chntext())
diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py
new file mode 100644
index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/date.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+"""DATE类
+日期 <=> 中文字符串 方法
+中文字符串 <=> 日期 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-07"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.digit import Digit
+
+
+class Date:
+ """
+ DATE类
+ """
+
+ def __init__(self, date=None, chntext=None):
+ self.date = date
+ self.chntext = chntext
+
+ # def chntext2date(self):
+ # chntext = self.chntext
+ # try:
+ # year, other = chntext.strip().split('年', maxsplit=1)
+ # year = Digit(chntext=year).digit2chntext() + '年'
+ # except ValueError:
+ # other = chntext
+ # year = ''
+ # if other:
+ # try:
+ # month, day = other.strip().split('月', maxsplit=1)
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
+ # except ValueError:
+ # day = chntext
+ # month = ''
+ # if day:
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+ # else:
+ # month = ''
+ # day = ''
+ # date = year + month + day
+ # self.date = date
+ # return self.date
+
+ def date2chntext(self):
+ date = self.date
+ try:
+ year, other = date.strip().split("年", maxsplit=1)
+ year = Digit(digit=year).digit2chntext() + "年"
+ except ValueError:
+ other = date
+ year = ""
+ if other:
+ try:
+ month, day = other.strip().split("月", maxsplit=1)
+ month = Cardinal(cardinal=month).cardinal2chntext() + "月"
+ except ValueError:
+ day = date
+ month = ""
+ if day:
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+ else:
+ month = ""
+ day = ""
+ chntext = year + month + day
+ self.chntext = chntext
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+ # 测试
+ print(Date(date="09年3月16日").date2chntext())
diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/digit.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""DIGIT类
+数字串 <=> 中文字符串 方法
+中文字符串 <=> 数字串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Digit:
+ """
+ DIGIT类
+ """
+
+ def __init__(self, digit=None, chntext=None):
+ self.digit = digit
+ self.chntext = chntext
+
+ # def chntext2digit(self):
+ # return chn2num(self.chntext)
+
+ def digit2chntext(self):
+ return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Digit(digit="2016").digit2chntext())
diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/fraction.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""FRACTION类
+分数 <=> 中文字符串 方法
+中文字符串 <=> 分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Fraction:
+ """
+ FRACTION类
+ """
+
+ def __init__(self, fraction=None, chntext=None):
+ self.fraction = fraction
+ self.chntext = chntext
+
+ def chntext2fraction(self):
+ denominator, numerator = self.chntext.split("分之")
+ return chn2num(numerator) + "/" + chn2num(denominator)
+
+ def fraction2chntext(self):
+ numerator, denominator = self.fraction.split("/")
+ return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Fraction(fraction="2135/7230").fraction2chntext())
+ print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/money.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+"""MONEY类
+金钱 <=> 中文字符串 方法
+中文字符串 <=> 金钱 方法
+"""
+import re
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-08"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+
+
+class Money:
+ """
+ MONEY类
+ """
+
+ def __init__(self, money=None, chntext=None):
+ self.money = money
+ self.chntext = chntext
+
+ # def chntext2money(self):
+ # return self.money
+
+ def money2chntext(self):
+ money = self.money
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(money)
+ if matchers:
+ for matcher in matchers:
+ money = money.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
+ )
+ self.chntext = money
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+ # 测试
+ print(Money(money="21.5万元").money2chntext())
+ print(Money(money="230块5毛").money2chntext())
diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py
new file mode 100644
index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/percentage.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""PERCENTAGE类
+百分数 <=> 中文字符串 方法
+中文字符串 <=> 百分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-06"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Percentage:
+ """
+ PERCENTAGE类
+ """
+
+ def __init__(self, percentage=None, chntext=None):
+ self.percentage = percentage
+ self.chntext = chntext
+
+ def chntext2percentage(self):
+ return chn2num(self.chntext.strip().strip("百分之")) + "%"
+
+ def percentage2chntext(self):
+ return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
+ print(Percentage(percentage="65.3%").percentage2chntext())
diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/telephone.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+"""TELEPHONE类
+电话号码 <=> 中文字符串 方法
+中文字符串 <=> 电话号码 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class TelePhone:
+ """
+ TELEPHONE类
+ """
+
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+ self.telephone = telephone
+ self.raw_chntext = raw_chntext
+ self.chntext = chntext
+
+ # def chntext2telephone(self):
+ # sil_parts = self.raw_chntext.split('')
+ # self.telephone = '-'.join([
+ # str(chn2num(p)) for p in sil_parts
+ # ])
+ # return self.telephone
+
+ def telephone2chntext(self, fixed=False):
+
+ if fixed:
+ sil_parts = self.telephone.split("-")
+ self.raw_chntext = "".join(
+ [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+ )
+ self.chntext = self.raw_chntext.replace("", "")
+ else:
+ sp_parts = self.telephone.strip("+").split()
+ self.raw_chntext = "".join(
+ [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+ )
+ self.chntext = self.raw_chntext.replace("", "")
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(TelePhone(telephone="0595-23980880").telephone2chntext())
+ # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/text.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+TEXT类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+import re
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.date import Date
+from fish_speech.text.chn_text_norm.digit import Digit
+from fish_speech.text.chn_text_norm.fraction import Fraction
+from fish_speech.text.chn_text_norm.money import Money
+from fish_speech.text.chn_text_norm.percentage import Percentage
+from fish_speech.text.chn_text_norm.telephone import TelePhone
+
+CURRENCY_NAMES = (
+ "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+ "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+)
+CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+COM_QUANTIFIERS = (
+ "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+ "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+ "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+ "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+ "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+ "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
+)
+
+
+class Text:
+ """
+ Text类
+ """
+
+ def __init__(self, raw_text, norm_text=None):
+ self.raw_text = "^" + raw_text + "$"
+ self.norm_text = norm_text
+
+ def _particular(self):
+ text = self.norm_text
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('particular')
+ for matcher in matchers:
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+ self.norm_text = text
+ return self.norm_text
+
+ def normalize(self):
+ text = self.raw_text
+
+ # 规范化日期
+ pattern = re.compile(
+ r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('date')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+ # 规范化金钱
+ pattern = re.compile(
+ r"\D+((\d+(\.\d+)?)[多余几]?"
+ + CURRENCY_UNITS
+ + "(\d"
+ + CURRENCY_UNITS
+ + "?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('money')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Money(money=matcher[0]).money2chntext(), 1
+ )
+
+ # 规范化固话/手机号码
+ # 手机
+ # http://www.jihaoba.com/news/show/13680
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
+ # 联通:130、131、132、156、155、186、185、176
+ # 电信:133、153、189、180、181、177
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+ )
+ # 固话
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fixed telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
+ 1,
+ )
+
+ # 规范化分数
+ pattern = re.compile(r"(\d+/\d+)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fraction')
+ for matcher in matchers:
+ text = text.replace(
+ matcher, Fraction(fraction=matcher).fraction2chntext(), 1
+ )
+
+ # 规范化百分数
+ text = text.replace("%", "%")
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('percentage')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ Percentage(percentage=matcher[0]).percentage2chntext(),
+ 1,
+ )
+
+ # 规范化纯数+量词
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal+quantifier')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ # 规范化数字编号
+ pattern = re.compile(r"(\d{4,32})")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('digit')
+ for matcher in matchers:
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+ # 规范化纯数
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ self.norm_text = text
+ self._particular()
+
+ return self.norm_text.lstrip("^").rstrip("$")
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
+ print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
+ print(Text(raw_text="分数:32477/76391。").normalize())
+ print(Text(raw_text="百分数:80.03%。").normalize())
+ print(Text(raw_text="编号:31520181154418。").normalize())
+ print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
+ print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
+ print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
+ print(Text(raw_text="特殊:O2O或B2C。").normalize())
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe1b331cd0a153fd55a4d598cf0aa79fbc5cbfc6
--- /dev/null
+++ b/fish_speech/text/clean.py
@@ -0,0 +1,37 @@
+import re
+
+SYMBOLS_MAPPING = {
+ "‘": "'",
+ "’": "'",
+}
+
+REPLACE_SYMBOL_REGEX = re.compile(
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+
+
+EMOJI_REGEX = re.compile(
+ "["
+ "\U0001F600-\U0001F64F" # emoticons
+ "\U0001F300-\U0001F5FF" # symbols & pictographs
+ "\U0001F680-\U0001F6FF" # transport & map symbols
+ "\U0001F1E0-\U0001F1FF" # flags (iOS)
+ "]+",
+ flags=re.UNICODE,
+)
+
+
+def clean_text(text):
+ # Clean the text
+ text = text.strip()
+
+ # Replace all chinese symbols with their english counterparts
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+
+ # Remove emojis
+ text = EMOJI_REGEX.sub(r"", text)
+
+ # Remove continuous periods (...) and commas (,,,)
+ text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
+
+ return text
diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py
new file mode 100644
index 0000000000000000000000000000000000000000..df079addb81cd91145f0c68f70b0da0d7251f036
--- /dev/null
+++ b/fish_speech/text/spliter.py
@@ -0,0 +1,130 @@
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
+def utf_8_len(text: str):
+ return len(text.encode("utf-8"))
+
+
+def break_text(texts, length, splits: set):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if char in splits:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def break_text_by_length(texts, length):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if utf_8_len(curr) >= length:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def add_cleaned(curr, segments):
+ curr = curr.strip()
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+ segments.append(curr)
+
+
+def protect_float(text):
+ # Turns 3.14 into <3_f_14> to prevent splitting
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+def unprotect_float(text):
+ # Turns <3_f_14> into 3.14
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+def split_text(text, length):
+ text = clean_text(text)
+
+ # Break the text into pieces with following rules:
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
+ # 2. If the text is longer than length, split at ","
+ # 3. If the text is still longer than length, split at " "
+ # 4. If the text is still longer than length, split at any character to length
+
+ texts = [text]
+ texts = map(protect_float, texts)
+ texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
+ texts = map(unprotect_float, texts)
+ texts = break_text(texts, length, {",", ","})
+ texts = break_text(texts, length, {" "})
+ texts = list(break_text_by_length(texts, length))
+
+ # Then, merge the texts into segments with length <= length
+ segments = []
+ curr = ""
+
+ for text in texts:
+ if utf_8_len(curr) + utf_8_len(text) <= length:
+ curr += text
+ else:
+ add_cleaned(curr, segments)
+ curr = text
+
+ if curr:
+ add_cleaned(curr, segments)
+
+ return segments
+
+
+if __name__ == "__main__":
+ # Test the split_text function
+
+ text = "This is a test sentence. This is another test sentence. And a third one."
+
+ assert split_text(text, 50) == [
+ "This is a test sentence.",
+ "This is another test sentence. And a third one.",
+ ]
+ assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+ assert split_text(" ", 10) == []
+ assert split_text("a", 10) == ["a"]
+
+ text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+ assert split_text(text, 50) == [
+ "This is a test sentence with only commas,",
+ "and no dots, and no exclamation marks,",
+ "and no question marks, and no newlines.",
+ ]
+
+ text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+ # First half split at " ", second half split at ","
+ assert split_text(text, 50) == [
+ "This is a test sentence This is a test sentence",
+ "This is a test sentence. This is a test sentence,",
+ "This is a test sentence, This is a test sentence.",
+ ]
+
+ text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+ assert split_text(text, 50) == [
+ "这是一段很长的中文文本,",
+ "而且没有句号,也没有感叹号,",
+ "也没有问号,也没有换行符.",
+ ]
diff --git a/fish_speech/tokenizer.py b/fish_speech/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..897cd060908e35c07ade8d37ee28ac3d8784f891
--- /dev/null
+++ b/fish_speech/tokenizer.py
@@ -0,0 +1,152 @@
+import base64
+import json
+import logging
+from pathlib import Path
+
+import tiktoken
+
+logger = logging.getLogger(__name__)
+
+# This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
+FISH_TIKTOKEN_PATTERN = "|".join(
+ [
+ r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+ r"\p{P}",
+ r"[^\r\n\p{L}\p{N}]?\p{L}+",
+ r"\p{N}",
+ r" ?[^\s\p{L}\p{N}]+[\r\n]*",
+ r"\s*[\r\n]+",
+ r"\s+(\?!\S)",
+ r"\s+",
+ ]
+)
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+BOS_TOKEN = "<|begin_of_text|>"
+EOS_TOKEN = "<|end_of_text|>"
+PAD_TOKEN = "<|pad|>"
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
+
+MODALITY_TEXT_TOKEN = "<|text|>"
+MODALITY_VOICE_TOKEN = "<|voice|>"
+MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
+MODALITY_TOKENS = {
+ "text": MODALITY_TEXT_TOKEN,
+ "voice": MODALITY_VOICE_TOKEN,
+ "interleave": MODALITY_INTERLEAVE_TOKEN,
+}
+
+PLACEHOLDER_TOKEN = [""] * 4
+for i in range(4):
+ PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
+
+SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
+SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
+
+# Warning: when you add a new special token, you should only add it to the end of the list.
+ALL_SPECIAL_TOKENS = [
+ BOS_TOKEN,
+ EOS_TOKEN,
+ PAD_TOKEN,
+ IM_START_TOKEN,
+ IM_END_TOKEN,
+ PLACEHOLDER_TOKEN[0],
+ PLACEHOLDER_TOKEN[1],
+ PLACEHOLDER_TOKEN[2],
+ PLACEHOLDER_TOKEN[3],
+ MODALITY_TEXT_TOKEN,
+ MODALITY_VOICE_TOKEN,
+ MODALITY_INTERLEAVE_TOKEN,
+ *SEMANTIC_TOKENS,
+]
+
+
+class FishTokenizer:
+ def __init__(self, model_path: str) -> None:
+ mergeable_ranks = self.load_tiktoken_bpe(model_path)
+ special_token_begin = len(mergeable_ranks)
+ self.all_special_tokens_with_ids = {
+ token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
+ }
+ self.semantic_id_to_token_id = {
+ i: self.all_special_tokens_with_ids[token]
+ for i, token in enumerate(SEMANTIC_TOKENS)
+ }
+ self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
+ self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
+
+ self.tkt_model = tiktoken.core.Encoding(
+ name=Path(model_path).stem,
+ pat_str=FISH_TIKTOKEN_PATTERN,
+ mergeable_ranks=mergeable_ranks,
+ special_tokens=self.all_special_tokens_with_ids,
+ )
+
+ @staticmethod
+ def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+ data = {}
+ for line in open(tiktoken_bpe_file).read().splitlines():
+ if not line:
+ continue
+ token, rank = line.split()
+ data[base64.b64decode(token)] = int(rank)
+ return data
+
+ def get_token_id(self, token: str) -> int:
+ return self.all_special_tokens_with_ids[token]
+
+ def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
+ assert isinstance(s, str)
+
+ subs = []
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
+ subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
+
+ if allowed_special is True:
+ allowed_special = self.tkt_model.special_tokens_set
+ elif allowed_special is False:
+ allowed_special = set()
+
+ return sum(
+ self.tkt_model.encode_batch(
+ subs, allowed_special=allowed_special, disallowed_special=set()
+ ),
+ start=[],
+ )
+
+ def decode(self, tokens: list[int]) -> str:
+ return self.tkt_model.decode(tokens)
+
+ def save_pretrained(self, path: str):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ with open(path / "tokenizer.tiktoken", "w") as f:
+ for token, rank in self.tkt_model._mergeable_ranks.items():
+ f.write(f"{base64.b64encode(token).decode()} {rank}\n")
+
+ with open(path / "special_tokens.json", "w") as f:
+ json.dump(
+ self.all_special_tokens_with_ids,
+ f,
+ indent=2,
+ ensure_ascii=False,
+ )
+
+ @staticmethod
+ def from_pretrained(path: str):
+ return FishTokenizer(Path(path) / "tokenizer.tiktoken")
+
+
+if __name__ == "__main__":
+ tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
+ tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
+ tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
+
+ print(
+ [
+ tokenizer.decode([i])
+ for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
+ ]
+ )
diff --git a/fish_speech/train.py b/fish_speech/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e693f3adc4dda787bdd587aec29f53355f2b1653
--- /dev/null
+++ b/fish_speech/train.py
@@ -0,0 +1,141 @@
+import os
+
+os.environ["USE_LIBUV"] = "0"
+import sys
+from typing import Optional
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
+from omegaconf import DictConfig, OmegaConf
+
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """ # noqa: E501
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=False)
+
+ if cfg.get("deterministic"):
+ torch.use_deterministic_algorithms(True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.trainer,
+ callbacks=callbacks,
+ logger=logger,
+ )
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(object_dict)
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+
+ ckpt_path = cfg.get("ckpt_path")
+ auto_resume = False
+
+ resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+ if resume_ckpt_path is not None:
+ ckpt_path = resume_ckpt_path
+ auto_resume = True
+
+ if ckpt_path is not None:
+ log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+ # resume weights only is disabled for auto-resume
+ if cfg.get("resume_weights_only") and auto_resume is False:
+ log.info("Resuming weights only!")
+ ckpt = torch.load(ckpt_path, map_location=model.device)
+ if "state_dict" in ckpt:
+ ckpt = ckpt["state_dict"]
+ err = model.load_state_dict(ckpt, strict=False)
+ log.info(f"Error loading state dict: {err}")
+ ckpt_path = None
+
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = cfg.get("ckpt_path")
+
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
+@hydra.main(
+ version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+ # train the model
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cf2f23174ddac9bf523730aca2f6a9965d134a
--- /dev/null
+++ b/fish_speech/utils/__init__.py
@@ -0,0 +1,24 @@
+from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, set_seed, task_wrapper
+
+__all__ = [
+ "enforce_tags",
+ "extras",
+ "get_metric_value",
+ "RankedLogger",
+ "instantiate_callbacks",
+ "instantiate_loggers",
+ "log_hyperparameters",
+ "print_config_tree",
+ "task_wrapper",
+ "braceexpand",
+ "get_latest_checkpoint",
+ "autocast_exclude_mps",
+ "set_seed",
+]
diff --git a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..969029da68ef2dbffd57ad2f18ca64871b2c232b
Binary files /dev/null and b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56312736b1749b7e98e67e88d9e97c296c396153
Binary files /dev/null and b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/context.cpython-310.pyc b/fish_speech/utils/__pycache__/context.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c403be8ac86e686d69d0fea0e196ad135b0d46c
Binary files /dev/null and b/fish_speech/utils/__pycache__/context.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/file.cpython-310.pyc b/fish_speech/utils/__pycache__/file.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21aa25de8dbe2db4073fe643bc207245c23a020f
Binary files /dev/null and b/fish_speech/utils/__pycache__/file.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87217f9fa3840beb8a3b08fe3d3999abd9358497
Binary files /dev/null and b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/logger.cpython-310.pyc b/fish_speech/utils/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..657b505ac10a27dbb6d298141b639f8b3bb83afe
Binary files /dev/null and b/fish_speech/utils/__pycache__/logger.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2f015147647aed509a7060b16757cdbe1934f52
Binary files /dev/null and b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0bf91a12c13e71f71ba9e64ac008c13c3e83446
Binary files /dev/null and b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f4fae227944024e3491170bc3802eed83d1476a
Binary files /dev/null and b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc differ
diff --git a/fish_speech/utils/__pycache__/utils.cpython-310.pyc b/fish_speech/utils/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83c8dc55984ebf108a5d60cb17269205d9dbadc2
Binary files /dev/null and b/fish_speech/utils/__pycache__/utils.cpython-310.pyc differ
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974
--- /dev/null
+++ b/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+ pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+ """braceexpand(pattern) -> iterator over generated strings
+
+ Returns an iterator over the strings resulting from brace expansion
+ of pattern. This function implements Brace Expansion as described in
+ bash(1), with the following limitations:
+
+ * A pattern containing unbalanced braces will raise an
+ UnbalancedBracesError exception. In bash, unbalanced braces will either
+ be partly expanded or ignored.
+
+ * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+ include the characters '[]^_`' between 'Z' and 'a'.
+
+ When escape is True (the default), characters in pattern can be
+ prefixed with a backslash to cause them not to be interpreted as
+ special characters for brace expansion (such as '{', '}', ',').
+ To pass through a a literal backslash, double it ('\\\\').
+
+ When escape is False, backslashes in pattern have no special
+ meaning and will be preserved in the output.
+
+ Examples:
+
+ >>> from braceexpand import braceexpand
+
+ # Integer range
+ >>> list(braceexpand('item{1..3}'))
+ ['item1', 'item2', 'item3']
+
+ # Character range
+ >>> list(braceexpand('{a..c}'))
+ ['a', 'b', 'c']
+
+ # Sequence
+ >>> list(braceexpand('index.html{,.backup}'))
+ ['index.html', 'index.html.backup']
+
+ # Nested patterns
+ >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+ ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+ # Prefixing an integer with zero causes all numbers to be padded to
+ # the same width.
+ >>> list(braceexpand('{07..10}'))
+ ['07', '08', '09', '10']
+
+ # An optional increment can be specified for ranges.
+ >>> list(braceexpand('{a..g..2}'))
+ ['a', 'c', 'e', 'g']
+
+ # Ranges can go in both directions.
+ >>> list(braceexpand('{4..1}'))
+ ['4', '3', '2', '1']
+
+ # Numbers can be negative
+ >>> list(braceexpand('{2..-1}'))
+ ['2', '1', '0', '-1']
+
+ # Unbalanced braces raise an exception.
+ >>> list(braceexpand('{1{2,3}'))
+ Traceback (most recent call last):
+ ...
+ UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+ # By default, the backslash is the escape character.
+ >>> list(braceexpand(r'{1\\{2,3}'))
+ ['1{2', '3']
+
+ # Setting 'escape' to False disables backslash escaping.
+ >>> list(braceexpand(r'\\{1,2}', escape=False))
+ ['\\\\1', '\\\\2']
+
+ """
+ return (
+ escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+ )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'pattern:', pattern
+ while pos < len(pattern):
+ if escape and pattern[pos] == "\\":
+ pos += 2
+ continue
+ elif pattern[pos] == "{":
+ if bracketdepth == 0 and pos > start:
+ # print 'literal:', pattern[start:pos]
+ items.append([pattern[start:pos]])
+ start = pos
+ bracketdepth += 1
+ elif pattern[pos] == "}":
+ bracketdepth -= 1
+ if bracketdepth == 0:
+ # print 'expression:', pattern[start+1:pos]
+ expr = pattern[start + 1 : pos]
+ item = parse_expression(expr, escape)
+ if item is None: # not a range or sequence
+ items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+ else:
+ items.append(item)
+ start = pos + 1 # skip the closing brace
+ pos += 1
+
+ if bracketdepth != 0: # unbalanced braces
+ raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+ if start < pos:
+ items.append([pattern[start:]])
+
+ return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+ int_range_match = int_range_re.match(expr)
+ if int_range_match:
+ return make_int_range(*int_range_match.groups())
+
+ char_range_match = char_range_re.match(expr)
+ if char_range_match:
+ return make_char_range(*char_range_match.groups())
+
+ return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+ # sequence -> chain(*sequence_items)
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'sequence:', seq
+ while pos < len(seq):
+ if escape and seq[pos] == "\\":
+ pos += 2
+ continue
+ elif seq[pos] == "{":
+ bracketdepth += 1
+ elif seq[pos] == "}":
+ bracketdepth -= 1
+ elif seq[pos] == "," and bracketdepth == 0:
+ items.append(parse_pattern(seq[start:pos], escape))
+ start = pos + 1 # skip the comma
+ pos += 1
+
+ if bracketdepth != 0:
+ raise UnbalancedBracesError
+ if not items:
+ return None
+
+ # part after the last comma (may be the empty string)
+ items.append(parse_pattern(seq[start:], escape))
+ return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+ if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+ padding = max(len(left), len(right))
+ else:
+ padding = 0
+ step = (int(incr) or 1) if incr else 1
+ start = int(left)
+ end = int(right)
+ r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+ fmt = "%0{}d".format(padding)
+ return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+ step = (int(incr) or 1) if incr else 1
+ start = alphabet.index(left)
+ end = alphabet.index(right)
+ if start < end:
+ return alphabet[start : end + 1 : step]
+ else:
+ end = end or -len(alphabet)
+ return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+ import doctest
+ import sys
+
+ failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+ if failed:
+ sys.exit(1)
diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..f04a99290ab32f7fe5b60656075a2d03af8468d6
--- /dev/null
+++ b/fish_speech/utils/context.py
@@ -0,0 +1,13 @@
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+ device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
+ return (
+ nullcontext()
+ if torch.backends.mps.is_available()
+ else torch.autocast(device_type, dtype)
+ )
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c82640a963fa556657107729f7543d2e7c3510
--- /dev/null
+++ b/fish_speech/utils/file.py
@@ -0,0 +1,16 @@
+import os
+from pathlib import Path
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+ # Find the latest checkpoint
+ ckpt_dir = Path(path)
+
+ if ckpt_dir.exists() is False:
+ return None
+
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+ if len(ckpts) == 0:
+ return None
+
+ return ckpts[-1]
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e
--- /dev/null
+++ b/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc
--- /dev/null
+++ b/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = True,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+ :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(
+ self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+ ) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError(
+ "The `rank_zero_only.rank` needs to be set before use"
+ )
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429
--- /dev/null
+++ b/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f
--- /dev/null
+++ b/fish_speech/utils/rich_utils.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """ # noqa: E501
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ (
+ queue.append(field)
+ if field in cfg
+ else log.warning(
+ f"Field '{field}' not found in config. "
+ + f"Skipping '{field}' config printing..."
+ )
+ )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841
--- /dev/null
+++ b/fish_speech/utils/spectrogram.py
@@ -0,0 +1,122 @@
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ center=False,
+ mode="pow2_sqrt",
+ ):
+ super().__init__()
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.mode = mode
+
+ self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+ def forward(self, y: Tensor) -> Tensor:
+ if y.ndim == 3:
+ y = y.squeeze(1)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ (self.win_length - self.hop_length) // 2,
+ (self.win_length - self.hop_length + 1) // 2,
+ ),
+ mode="reflect",
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+
+ spec = torch.view_as_real(spec)
+
+ if self.mode == "pow2_sqrt":
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ return spec
+
+
+class LogMelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ sample_rate=44100,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ n_mels=128,
+ center=False,
+ f_min=0.0,
+ f_max=None,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max or float(sample_rate // 2)
+
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+ fb = F.melscale_fbanks(
+ n_freqs=self.n_fft // 2 + 1,
+ f_min=self.f_min,
+ f_max=self.f_max,
+ n_mels=self.n_mels,
+ sample_rate=self.sample_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ )
+ self.register_buffer(
+ "fb",
+ fb,
+ persistent=False,
+ )
+
+ def compress(self, x: Tensor) -> Tensor:
+ return torch.log(torch.clamp(x, min=1e-5))
+
+ def decompress(self, x: Tensor) -> Tensor:
+ return torch.exp(x)
+
+ def apply_mel_scale(self, x: Tensor) -> Tensor:
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ def forward(
+ self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+ ) -> Tensor:
+ if sample_rate is not None and sample_rate != self.sample_rate:
+ x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+ linear = self.spectrogram(x)
+ x = self.apply_mel_scale(linear)
+ x = self.compress(x)
+
+ if return_linear:
+ return x, self.compress(linear)
+
+ return x
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a34bdcfedff76c333f50ed8be050d0dd5a8f98a
--- /dev/null
+++ b/fish_speech/utils/utils.py
@@ -0,0 +1,136 @@
+import random
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+ ...
+
+ return metric_dict, object_dict
+ ```
+ """ # noqa: E501
+
+ def wrap(cfg: DictConfig):
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or
+ # cause out-of-memory errors so when using hparam search
+ # plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.run_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
+
+
+def set_seed(seed: int):
+ if seed < 0:
+ seed = -seed
+ if seed > (1 << 31):
+ seed = 1 << 31
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ if torch.backends.cudnn.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7
--- /dev/null
+++ b/fish_speech/webui/css/style.css
@@ -0,0 +1,161 @@
+:root {
+ --my-200: #80eeee;
+ --my-50: #ecfdf5;
+ --water-width: 300px;
+ --water-heigh: 300px;
+}
+
+
+/* general styled components */
+.tools {
+ align-items: center;
+ justify-content: center;
+}
+
+.gradio-button {
+ max-width: 2.2em;
+ min-width: 2.2em !important;
+ height: 2.4em;
+ align-self: end;
+ line-height: 1em;
+ border-radius: 0.5em;
+
+}
+
+.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
+ box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
+}
+
+/* replace original footer with ours */
+a{
+ font-weight: bold;
+ cursor: pointer;
+ color: #030C14 !important;
+}
+
+footer {
+ display: none !important;
+}
+
+#footer{
+ text-align: center;
+}
+
+#footer div{
+ display: inline-block;
+}
+
+#footer .versions{
+ font-size: 85%;
+ opacity: 0.85;
+}
+
+/*@keyframes moveBackground {*/
+/* 0% {*/
+/* background-position: 0 0;*/
+/* }*/
+/* 100% {*/
+/* background-position: -100px 100px;*/
+/* }*/
+/*}*/
+@keyframes moveJellyBackground {
+ 0% {
+ background-position: 0% 50%;
+ }
+ 50% {
+ background-position: 100% 50%;
+ }
+ 100% {
+ background-position: 0% 50%;
+ }
+}
+
+.gradio-container {
+ position: absolute;
+ z-index: 10;
+}
+
+
+.quan {
+ position: absolute;
+ bottom: 0;
+ width: var(--water-width);
+ height: var(--water-heigh);
+ border-radius: 0;
+ /*border: 3px solid rgb(246, 247, 248);*/
+ /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
+ z-index: 0;
+
+}
+
+.quan:last-child {
+ margin-right: 0;
+}
+
+.shui {
+ position: absolute;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgb(23, 106, 201);
+ border-radius: 0;
+ overflow: hidden;
+ z-index: 0;
+}
+
+.shui::after {
+
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 40%;
+ background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
+ animation: shi 5s linear infinite;
+}
+
+@keyframes shi {
+ 0% {
+ transform: translate(-50%, -65%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -65%) rotate(360deg);
+ }
+}
+
+.shui::before {
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 42%;
+ background-color: rgb(240, 228, 228, 0.2);
+ animation: xu 7s linear infinite;
+}
+
+@keyframes xu {
+ 0% {
+ transform: translate(-50%, -60%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -60%) rotate(360deg);
+ }
+}
+
+fieldset.data_src div.wrap label {
+ background: #f8bffee0 !important;
+}
+
+.scrollable-component {
+ max-height: 100px;
+ overflow-y: auto;
+}
+
+#file_accordion {
+ max-height: 220px !important;
+}
diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html
new file mode 100644
index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615
--- /dev/null
+++ b/fish_speech/webui/html/footer.html
@@ -0,0 +1,11 @@
+
+
+
+{versions}
+
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js
new file mode 100644
index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868
--- /dev/null
+++ b/fish_speech/webui/js/animate.js
@@ -0,0 +1,69 @@
+
+function createGradioAnimation() {
+ const params = new URLSearchParams(window.location.search);
+ if (!params.has('__theme')) {
+ params.set('__theme', 'light');
+ window.location.search = params.toString();
+ }
+
+ var gradioApp = document.querySelector('gradio-app');
+ if (gradioApp) {
+
+ document.documentElement.style.setProperty('--my-200', '#80eeee');
+ document.documentElement.style.setProperty('--my-50', '#ecfdf5');
+
+ // gradioApp.style.position = 'relative';
+ // gradioApp.style.backgroundSize = '200% 200%';
+ // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
+ // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
+ // gradioApp.style.display = 'flex';
+ // gradioApp.style.justifyContent = 'flex-start';
+ // gradioApp.style.flexWrap = 'nowrap';
+ // gradioApp.style.overflowX = 'auto';
+
+ // for (let i = 0; i < 6; i++) {
+ // var quan = document.createElement('div');
+ // quan.className = 'quan';
+ // gradioApp.insertBefore(quan, gradioApp.firstChild);
+ // quan.id = 'quan' + i.toString();
+ // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
+ // var quanContainer = document.querySelector('.quan');
+ // if (quanContainer) {
+ // var shui = document.createElement('div');
+ // shui.className = 'shui';
+ // quanContainer.insertBefore(shui, quanContainer.firstChild)
+ // }
+ // }
+ }
+
+ var container = document.createElement('div');
+ container.id = 'gradio-animation';
+ container.style.fontSize = '2em';
+ container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
+ container.style.fontWeight = 'bold';
+ container.style.textAlign = 'center';
+ container.style.marginBottom = '20px';
+
+ var text = 'Welcome to Fish-Speech!';
+ for (var i = 0; i < text.length; i++) {
+ (function(i){
+ setTimeout(function(){
+ var letter = document.createElement('span');
+ letter.style.opacity = '0';
+ letter.style.transition = 'opacity 0.5s';
+ letter.innerText = text[i];
+
+ container.appendChild(letter);
+
+ setTimeout(function() {
+ letter.style.opacity = '1';
+ }, 50);
+ }, i * 200);
+ })(i);
+ }
+
+ var gradioContainer = document.querySelector('.gradio-container');
+ gradioContainer.insertBefore(container, gradioContainer.firstChild);
+
+ return 'Animation created';
+}
diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..790c0e632ce55e099e5578d8824e94b1d1260d6e
--- /dev/null
+++ b/fish_speech/webui/launch_utils.py
@@ -0,0 +1,120 @@
+import importlib.util
+import os
+import subprocess
+import sys
+from functools import lru_cache
+from pathlib import Path
+from typing import Iterable
+
+import gradio as gr
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+
+GIT = (
+ (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
+ if sys.platform == "win32"
+ else "git"
+)
+GIT = str(GIT)
+
+
+def is_module_installed(module_name: str) -> bool:
+ spec = importlib.util.find_spec(module_name)
+ return spec is not None
+
+
+@lru_cache()
+def commit_hash():
+ try:
+ return subprocess.check_output(
+ [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
+ ).strip()
+ except Exception:
+ return ""
+
+
+def versions_html():
+ import torch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = commit_hash()
+ hash = commit.strip("'").split(" ")[0]
+
+ return f"""
+version: {hash}
+ •
+python: {python_version}
+ •
+torch: {getattr(torch, '__long_version__',torch.__version__)}
+ •
+gradio: {gr.__version__}
+ •
+author: fishaudio
+"""
+
+
+def version_check(commit):
+ try:
+ import requests
+
+ commits = requests.get(
+ "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
+ ).json()
+ if commit != "" and commits["commit"]["sha"] != commit:
+ print("--------------------------------------------------------")
+ print("| You are not up to date with the most recent release. |")
+ print("| Consider running `git pull` to update. |")
+ print("--------------------------------------------------------")
+ elif commits["commit"]["sha"] == commit:
+ print("You are up to date with the most recent release.")
+ else:
+ print("Not a git clone, can't perform version check.")
+ except Exception as e:
+ print("version check failed", e)
+
+
+class Seafoam(Base):
+ def __init__(
+ self,
+ *,
+ primary_hue: colors.Color | str = colors.emerald,
+ secondary_hue: colors.Color | str = colors.blue,
+ neutral_hue: colors.Color | str = colors.blue,
+ spacing_size: sizes.Size | str = sizes.spacing_md,
+ radius_size: sizes.Size | str = sizes.radius_md,
+ text_size: sizes.Size | str = sizes.text_lg,
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("Quicksand"),
+ "ui-sans-serif",
+ "sans-serif",
+ ),
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("IBM Plex Mono"),
+ "ui-monospace",
+ "monospace",
+ ),
+ ):
+ super().__init__(
+ primary_hue=primary_hue,
+ secondary_hue=secondary_hue,
+ neutral_hue=neutral_hue,
+ spacing_size=spacing_size,
+ radius_size=radius_size,
+ text_size=text_size,
+ font=font,
+ font_mono=font_mono,
+ )
+ super().set(
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
+ button_primary_text_color="white",
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
+ slider_color="*secondary_300",
+ slider_color_dark="*secondary_600",
+ block_title_text_weight="600",
+ block_border_width="3px",
+ block_shadow="*shadow_drop_lg",
+ # button_shadow="*shadow_drop_lg",
+ button_small_padding="0px",
+ button_large_padding="3px",
+ )
diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d52a35e2d5612dd286f48d9515cebe4bc93b3bf
--- /dev/null
+++ b/fish_speech/webui/manage.py
@@ -0,0 +1,1239 @@
+from __future__ import annotations
+
+import os
+
+os.environ["USE_LIBUV"] = "0"
+import datetime
+import html
+import json
+import platform
+import shutil
+import signal
+import subprocess
+import sys
+from pathlib import Path
+
+import gradio as gr
+import psutil
+import yaml
+from loguru import logger
+from tqdm import tqdm
+
+PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
+sys.path.insert(0, "")
+print(sys.path)
+cur_work_dir = Path(os.getcwd()).resolve()
+print("You are in ", str(cur_work_dir))
+
+from fish_speech.i18n import i18n
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
+
+config_path = cur_work_dir / "fish_speech" / "configs"
+vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
+llama_yml_path = config_path / "text2semantic_finetune.yaml"
+
+env = os.environ.copy()
+env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
+
+seafoam = Seafoam()
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(error)}
+
+ """
+
+
+def build_html_ok_message(msg):
+ return f"""
+
+ {html.escape(msg)}
+
+ """
+
+
+def build_html_href(link, desc, msg):
+ return f"""
+
+ {html.escape(msg)}
+ {desc}
+
+ """
+
+
+def load_data_in_raw(path):
+ with open(path, "r", encoding="utf-8") as file:
+ data = file.read()
+ return str(data)
+
+
+def kill_proc_tree(pid, including_parent=True):
+ try:
+ parent = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ # Process already terminated
+ return
+
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
+ except OSError:
+ pass
+ if including_parent:
+ try:
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
+ except OSError:
+ pass
+
+
+system = platform.system()
+p_label = None
+p_infer = None
+p_tensorboard = None
+
+
+def kill_process(pid):
+ if system == "Windows":
+ cmd = "taskkill /t /f /pid %s" % pid
+ # os.system(cmd)
+ subprocess.run(cmd)
+ else:
+ kill_proc_tree(pid)
+
+
+def change_label(if_label):
+ global p_label
+ if if_label == True and p_label is None:
+ url = "http://localhost:3000"
+ remote_url = "https://text-labeler.pages.dev/"
+ try:
+ p_label = subprocess.Popen(
+ [
+ (
+ "asr-label-linux-x64"
+ if sys.platform == "linux"
+ else "asr-label-win-x64.exe"
+ )
+ ]
+ )
+ except FileNotFoundError:
+ logger.warning("asr-label execution not found!")
+
+ yield build_html_href(
+ link=remote_url,
+ desc=i18n("Optional online ver"),
+ msg=i18n("Opened labeler in browser"),
+ )
+
+ elif if_label == False and p_label is not None:
+ kill_process(p_label.pid)
+ p_label = None
+ yield build_html_ok_message("Nothing")
+
+
+def clean_infer_cache():
+ import tempfile
+
+ temp_dir = Path(tempfile.gettempdir())
+ gradio_dir = str(temp_dir / "gradio")
+ try:
+ shutil.rmtree(gradio_dir)
+ logger.info(f"Deleted cached audios: {gradio_dir}")
+ except PermissionError:
+ logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+ except FileNotFoundError:
+ logger.info(f"{gradio_dir} was not found")
+ except Exception as e:
+ logger.info(f"An error occurred: {e}")
+
+
+def change_infer(
+ if_infer,
+ host,
+ port,
+ infer_decoder_model,
+ infer_decoder_config,
+ infer_llama_model,
+ infer_compile,
+):
+ global p_infer
+ if if_infer == True and p_infer == None:
+ env = os.environ.copy()
+
+ env["GRADIO_SERVER_NAME"] = host
+ env["GRADIO_SERVER_PORT"] = port
+ # 启动第二个进程
+ url = f"http://{host}:{port}"
+ yield build_html_ok_message(
+ i18n("Inferring interface is launched at {}").format(url)
+ )
+
+ clean_infer_cache()
+
+ p_infer = subprocess.Popen(
+ [
+ PYTHON,
+ "tools/run_webui.py",
+ "--decoder-checkpoint-path",
+ infer_decoder_model,
+ "--decoder-config-name",
+ infer_decoder_config,
+ "--llama-checkpoint-path",
+ infer_llama_model,
+ ]
+ + (["--compile"] if infer_compile == "Yes" else []),
+ env=env,
+ )
+
+ elif if_infer == False and p_infer is not None:
+ kill_process(p_infer.pid)
+ p_infer = None
+ yield build_html_error_message(i18n("Infer interface is closed"))
+
+
+js = load_data_in_raw("fish_speech/webui/js/animate.js")
+css = load_data_in_raw("fish_speech/webui/css/style.css")
+
+data_pre_output = (cur_work_dir / "data").resolve()
+default_model_output = (cur_work_dir / "results").resolve()
+default_filelist = data_pre_output / "detect.list"
+data_pre_output.mkdir(parents=True, exist_ok=True)
+
+items = []
+dict_items = {}
+
+
+def load_yaml_data_in_fact(yml_path):
+ with open(yml_path, "r", encoding="utf-8") as file:
+ yml = yaml.safe_load(file)
+ return yml
+
+
+def write_yaml_data_in_fact(yml, yml_path):
+ with open(yml_path, "w", encoding="utf-8") as file:
+ yaml.safe_dump(yml, file, allow_unicode=True)
+ return yml
+
+
+def generate_tree(directory, depth=0, max_depth=None, prefix=""):
+ if max_depth is not None and depth > max_depth:
+ return ""
+
+ tree_str = ""
+ files = []
+ directories = []
+ for item in os.listdir(directory):
+ if os.path.isdir(os.path.join(directory, item)):
+ directories.append(item)
+ else:
+ files.append(item)
+
+ entries = directories + files
+ for i, entry in enumerate(entries):
+ connector = "├── " if i < len(entries) - 1 else "└── "
+ tree_str += f"{prefix}{connector}{entry} "
+ if i < len(directories):
+ extension = "│ " if i < len(entries) - 1 else " "
+ tree_str += generate_tree(
+ os.path.join(directory, entry),
+ depth + 1,
+ max_depth,
+ prefix=prefix + extension,
+ )
+ return tree_str
+
+
+def new_explorer(data_path, max_depth):
+ return gr.Markdown(
+ elem_classes=["scrollable-component"],
+ value=generate_tree(data_path, max_depth=max_depth),
+ )
+
+
+def add_item(
+ folder: str,
+ method: str,
+ label_lang: str,
+ if_initial_prompt: bool,
+ initial_prompt: str | None,
+):
+ folder = folder.strip(" ").strip('"')
+
+ folder_path = Path(folder)
+
+ if folder and folder not in items and data_pre_output not in folder_path.parents:
+ if folder_path.is_dir():
+ items.append(folder)
+ dict_items[folder] = dict(
+ type="folder",
+ method=method,
+ label_lang=label_lang,
+ initial_prompt=initial_prompt if if_initial_prompt else None,
+ )
+ elif folder:
+ err = folder
+ return gr.Checkboxgroup(choices=items), build_html_error_message(
+ i18n("Invalid path: {}").format(err)
+ )
+
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info("After Adding: " + formatted_data)
+ gr.Info(formatted_data)
+ return gr.Checkboxgroup(choices=items), build_html_ok_message(
+ i18n("Added path successfully!")
+ )
+
+
+def remove_items(selected_items):
+ global items, dict_items
+ to_remove = [item for item in items if item in selected_items]
+ for item in to_remove:
+ del dict_items[item]
+ items = [item for item in items if item in dict_items.keys()]
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info(formatted_data)
+ gr.Warning("After Removing: " + formatted_data)
+ return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
+ i18n("Removed path successfully!")
+ )
+
+
+def show_selected(options):
+ selected_options = ", ".join(options)
+
+ if options:
+ return i18n("Selected: {}").format(selected_options)
+ else:
+ return i18n("No selected options")
+
+
+from pydub import AudioSegment
+
+
+def convert_to_mono_in_place(audio_path: Path):
+ audio = AudioSegment.from_file(audio_path)
+ if audio.channels > 1:
+ mono_audio = audio.set_channels(1)
+ mono_audio.export(audio_path, format=audio_path.suffix[1:])
+ logger.info(f"Convert {audio_path} successfully")
+
+
+def list_copy(list_file_path, method):
+ wav_root = data_pre_output
+ lst = []
+ with list_file_path.open("r", encoding="utf-8") as file:
+ for line in tqdm(file, desc="Processing audio/transcript"):
+ wav_path, speaker_name, language, text = line.strip().split("|")
+ original_wav_path = Path(wav_path)
+ target_wav_path = (
+ wav_root / original_wav_path.parent.name / original_wav_path.name
+ )
+ lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
+ if target_wav_path.is_file():
+ continue
+ target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+ if method == i18n("Copy"):
+ shutil.copy(original_wav_path, target_wav_path)
+ else:
+ shutil.move(original_wav_path, target_wav_path.parent)
+ convert_to_mono_in_place(target_wav_path)
+ original_lab_path = original_wav_path.with_suffix(".lab")
+ target_lab_path = (
+ wav_root
+ / original_wav_path.parent.name
+ / original_wav_path.with_suffix(".lab").name
+ )
+ if target_lab_path.is_file():
+ continue
+ if method == i18n("Copy"):
+ shutil.copy(original_lab_path, target_lab_path)
+ else:
+ shutil.move(original_lab_path, target_lab_path.parent)
+
+ if method == i18n("Move"):
+ with list_file_path.open("w", encoding="utf-8") as file:
+ file.writelines("\n".join(lst))
+
+ del lst
+ return build_html_ok_message(i18n("Use filelist"))
+
+
+def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
+ global dict_items
+ data_path = Path(data_path)
+ gr.Warning("Pre-processing begins...")
+ for item, content in dict_items.items():
+ item_path = Path(item)
+ tar_path = data_path / item_path.name
+
+ if content["type"] == "folder" and item_path.is_dir():
+ if content["method"] == i18n("Copy"):
+ os.makedirs(tar_path, exist_ok=True)
+ shutil.copytree(
+ src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+ )
+ elif not tar_path.is_dir():
+ shutil.move(src=str(item_path), dst=str(tar_path))
+
+ for suf in ["wav", "flac", "mp3"]:
+ for audio_path in tar_path.glob(f"**/*.{suf}"):
+ convert_to_mono_in_place(audio_path)
+
+ cur_lang = content["label_lang"]
+ initial_prompt = content["initial_prompt"]
+
+ transcribe_cmd = [
+ PYTHON,
+ "tools/whisper_asr.py",
+ "--model-size",
+ label_model,
+ "--device",
+ label_device,
+ "--audio-dir",
+ tar_path,
+ "--save-dir",
+ tar_path,
+ "--language",
+ cur_lang,
+ ]
+
+ if initial_prompt is not None:
+ transcribe_cmd += ["--initial-prompt", initial_prompt]
+
+ if cur_lang != "IGNORE":
+ try:
+ gr.Warning("Begin To Transcribe")
+ subprocess.run(
+ transcribe_cmd,
+ env=env,
+ )
+ except Exception:
+ print("Transcription error occurred")
+
+ elif content["type"] == "file" and item_path.is_file():
+ list_copy(item_path, content["method"])
+
+ return build_html_ok_message(i18n("Move files successfully")), new_explorer(
+ data_path, max_depth=max_depth
+ )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+def train_process(
+ data_path: str,
+ option: str,
+ # llama config
+ llama_ckpt,
+ llama_base_config,
+ llama_lr,
+ llama_maxsteps,
+ llama_data_num_workers,
+ llama_data_batch_size,
+ llama_data_max_length,
+ llama_precision,
+ llama_check_interval,
+ llama_grad_batches,
+ llama_use_speaker,
+ llama_use_lora,
+):
+
+ backend = "nccl" if sys.platform == "linux" else "gloo"
+
+ new_project = generate_folder_name()
+ print("New Project Name: ", new_project)
+
+ if option == "VQGAN":
+ msg = "Skipped VQGAN Training."
+ gr.Warning(msg)
+ logger.info(msg)
+
+ if option == "LLAMA":
+ msg = "LLAMA Training begins..."
+ gr.Warning(msg)
+ logger.info(msg)
+ subprocess.run(
+ [
+ PYTHON,
+ "tools/vqgan/extract_vq.py",
+ str(data_pre_output),
+ "--num-workers",
+ "1",
+ "--batch-size",
+ "16",
+ "--config-name",
+ "firefly_gan_vq",
+ "--checkpoint-path",
+ "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ ]
+ )
+
+ subprocess.run(
+ [
+ PYTHON,
+ "tools/llama/build_dataset.py",
+ "--input",
+ str(data_pre_output),
+ "--text-extension",
+ ".lab",
+ "--num-workers",
+ "16",
+ ]
+ )
+ ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
+ lora_prefix = "lora_" if llama_use_lora else ""
+ llama_name = lora_prefix + "text2semantic_" + new_project
+ latest = next(
+ iter(
+ sorted(
+ [
+ str(p.relative_to("results"))
+ for p in Path("results").glob(lora_prefix + "text2sem*/")
+ ],
+ reverse=True,
+ )
+ ),
+ llama_name,
+ )
+ project = (
+ llama_name
+ if llama_ckpt == i18n("new")
+ else (
+ latest
+ if llama_ckpt == i18n("latest")
+ else Path(llama_ckpt).relative_to("results")
+ )
+ )
+ logger.info(project)
+
+ if llama_check_interval > llama_maxsteps:
+ llama_check_interval = llama_maxsteps
+
+ train_cmd = [
+ PYTHON,
+ "fish_speech/train.py",
+ "--config-name",
+ "text2semantic_finetune",
+ f"project={project}",
+ f"trainer.strategy.process_group_backend={backend}",
+ f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+ f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+ f"model.optimizer.lr={llama_lr}",
+ f"trainer.max_steps={llama_maxsteps}",
+ f"data.num_workers={llama_data_num_workers}",
+ f"data.batch_size={llama_data_batch_size}",
+ f"max_length={llama_data_max_length}",
+ f"trainer.precision={llama_precision}",
+ f"trainer.val_check_interval={llama_check_interval}",
+ f"trainer.accumulate_grad_batches={llama_grad_batches}",
+ f"train_dataset.interactive_prob={llama_use_speaker}",
+ ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+ logger.info(train_cmd)
+ subprocess.run(train_cmd)
+
+ return build_html_ok_message(i18n("Training stopped"))
+
+
+def tensorboard_process(
+ if_tensorboard: bool,
+ tensorboard_dir: str,
+ host: str,
+ port: str,
+):
+ global p_tensorboard
+ if if_tensorboard == True and p_tensorboard == None:
+ url = f"http://{host}:{port}"
+ yield build_html_ok_message(
+ i18n("Tensorboard interface is launched at {}").format(url)
+ )
+ prefix = ["tensorboard"]
+ if Path("fishenv").exists():
+ prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
+
+ p_tensorboard = subprocess.Popen(
+ prefix
+ + [
+ "--logdir",
+ tensorboard_dir,
+ "--host",
+ host,
+ "--port",
+ port,
+ "--reload_interval",
+ "120",
+ ]
+ )
+ elif if_tensorboard == False and p_tensorboard != None:
+ kill_process(p_tensorboard.pid)
+ p_tensorboard = None
+ yield build_html_error_message(i18n("Tensorboard interface is closed"))
+
+
+def fresh_tb_dir():
+ return gr.Dropdown(
+ choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
+ )
+
+
+def list_decoder_models():
+ paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
+ if not paths:
+ logger.warning("No decoder model found")
+ return paths
+
+
+def list_llama_models():
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+ choices = sorted(choices, reverse=True)
+ if not choices:
+ logger.warning("No LLaMA model found")
+ return choices
+
+
+def list_lora_llama_models():
+ choices = sorted(
+ [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+ )
+ if not choices:
+ logger.warning("No LoRA LLaMA model found")
+ return choices
+
+
+def fresh_decoder_model():
+ return gr.Dropdown(choices=list_decoder_models())
+
+
+def fresh_llama_ckpt(llama_use_lora):
+ return gr.Dropdown(
+ choices=[i18n("latest"), i18n("new")]
+ + (
+ [str(p) for p in Path("results").glob("text2sem*/")]
+ if not llama_use_lora
+ else [str(p) for p in Path("results").glob("lora_*/")]
+ )
+ )
+
+
+def fresh_llama_model():
+ return gr.Dropdown(choices=list_llama_models())
+
+
+def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
+ if (
+ lora_weight is None
+ or not Path(lora_weight).exists()
+ or not Path(llama_weight).exists()
+ ):
+ return build_html_error_message(
+ i18n(
+ "Path error, please check the model file exists in the corresponding path"
+ )
+ )
+ gr.Warning("Merging begins...")
+ merge_cmd = [
+ PYTHON,
+ "tools/llama/merge_lora.py",
+ "--lora-config",
+ "r_8_alpha_16",
+ "--lora-weight",
+ lora_weight,
+ "--output",
+ llama_lora_output + "_" + generate_folder_name(),
+ ]
+ logger.info(merge_cmd)
+ subprocess.run(merge_cmd)
+ return build_html_ok_message(i18n("Merge successfully"))
+
+
+def llama_quantify(llama_weight, quantify_mode):
+ if llama_weight is None or not Path(llama_weight).exists():
+ return build_html_error_message(
+ i18n(
+ "Path error, please check the model file exists in the corresponding path"
+ )
+ )
+
+ gr.Warning("Quantifying begins...")
+
+ now = generate_folder_name()
+ quantify_cmd = [
+ PYTHON,
+ "tools/llama/quantize.py",
+ "--checkpoint-path",
+ llama_weight,
+ "--mode",
+ quantify_mode,
+ "--timestamp",
+ now,
+ ]
+ logger.info(quantify_cmd)
+ subprocess.run(quantify_cmd)
+ if quantify_mode == "int8":
+ quantize_path = str(
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
+ )
+ else:
+ quantize_path = str(
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
+ )
+ return build_html_ok_message(
+ i18n("Quantify successfully") + f"Path: {quantize_path}"
+ )
+
+
+init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
+init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
+
+with gr.Blocks(
+ head="",
+ js=js,
+ theme=seafoam,
+ analytics_enabled=False,
+ title="Fish Speech",
+) as demo:
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
+ with gr.Row():
+ textbox = gr.Textbox(
+ label="\U0000270F "
+ + i18n("Input Audio & Source Path for Transcription"),
+ info=i18n("Speaker is identified by the folder name"),
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ output_radio = gr.Radio(
+ label="\U0001F4C1 "
+ + i18n("Select source file processing method"),
+ choices=[i18n("Copy"), i18n("Move")],
+ value=i18n("Copy"),
+ interactive=True,
+ )
+ with gr.Column():
+ error = gr.HTML(label=i18n("Error Message"))
+ if_label = gr.Checkbox(
+ label=i18n("Open Labeler WebUI"), scale=0, show_label=True
+ )
+
+ with gr.Row():
+ label_device = gr.Dropdown(
+ label=i18n("Labeling Device"),
+ info=i18n(
+ "It is recommended to use CUDA, if you have low configuration, use CPU"
+ ),
+ choices=["cpu", "cuda"],
+ value="cuda",
+ interactive=True,
+ )
+ label_model = gr.Dropdown(
+ label=i18n("Whisper Model"),
+ info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
+ choices=["large-v3", "medium"],
+ value="large-v3",
+ interactive=True,
+ )
+ label_radio = gr.Dropdown(
+ label=i18n("Optional Label Language"),
+ info=i18n(
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
+ ),
+ choices=[
+ (i18n("Chinese"), "zh"),
+ (i18n("English"), "en"),
+ (i18n("Japanese"), "ja"),
+ (i18n("Disabled"), "IGNORE"),
+ (i18n("auto"), "auto"),
+ ],
+ value="IGNORE",
+ interactive=True,
+ )
+
+ with gr.Row():
+ if_initial_prompt = gr.Checkbox(
+ value=False,
+ label=i18n("Enable Initial Prompt"),
+ min_width=120,
+ scale=0,
+ )
+ initial_prompt = gr.Textbox(
+ label=i18n("Initial Prompt"),
+ info=i18n(
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
+ ),
+ placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
+ interactive=False,
+ )
+
+ with gr.Row():
+ add_button = gr.Button(
+ "\U000027A1 " + i18n("Add to Processing Area"),
+ variant="primary",
+ )
+ remove_button = gr.Button(
+ "\U000026D4 " + i18n("Remove Selected Data")
+ )
+
+ with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
+ with gr.Row():
+ model_type_radio = gr.Radio(
+ label=i18n(
+ "Select the model to be trained (Depending on the Tab page you are on)"
+ ),
+ interactive=False,
+ choices=["VQGAN", "LLAMA"],
+ value="VQGAN",
+ )
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
+ gr.HTML("You don't need to train this model!")
+
+ with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
+ with gr.Row(equal_height=False):
+ llama_use_lora = gr.Checkbox(
+ label=i18n("Use LoRA"),
+ info=i18n(
+ "Use LoRA can save GPU memory, but may reduce the quality of the model"
+ ),
+ value=True,
+ interactive=True,
+ )
+ llama_ckpt = gr.Dropdown(
+ label=i18n("Select LLAMA ckpt"),
+ choices=[i18n("latest"), i18n("new")]
+ + [
+ str(p)
+ for p in Path("results").glob("text2sem*/")
+ ]
+ + [str(p) for p in Path("results").glob("lora*/")],
+ value=i18n("latest"),
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lr_slider = gr.Slider(
+ label=i18n("Initial Learning Rate"),
+ info=i18n(
+ "lr smaller -> usually train slower but more stable"
+ ),
+ interactive=True,
+ minimum=1e-5,
+ maximum=1e-4,
+ step=1e-5,
+ value=5e-5,
+ )
+ llama_maxsteps_slider = gr.Slider(
+ label=i18n("Maximum Training Steps"),
+ info=i18n(
+ "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+ ),
+ interactive=True,
+ minimum=1,
+ maximum=10000,
+ step=1,
+ value=50,
+ )
+ with gr.Row(equal_height=False):
+ llama_base_config = gr.Dropdown(
+ label=i18n("Model Size"),
+ choices=[
+ "text2semantic_finetune",
+ ],
+ value="text2semantic_finetune",
+ )
+ llama_data_num_workers_slider = gr.Slider(
+ label=i18n("Number of Workers"),
+ minimum=1,
+ maximum=32,
+ step=1,
+ value=4,
+ )
+ with gr.Row(equal_height=False):
+ llama_data_batch_size_slider = gr.Slider(
+ label=i18n("Batch Size"),
+ interactive=True,
+ minimum=1,
+ maximum=32,
+ step=1,
+ value=2,
+ )
+ llama_data_max_length_slider = gr.Slider(
+ label=i18n("Maximum Length per Sample"),
+ interactive=True,
+ minimum=1024,
+ maximum=4096,
+ step=128,
+ value=2048,
+ )
+ with gr.Row(equal_height=False):
+ llama_precision_dropdown = gr.Dropdown(
+ label=i18n("Precision"),
+ info=i18n(
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
+ ),
+ interactive=True,
+ choices=["32", "bf16-true", "16-mixed"],
+ value="bf16-true",
+ )
+ llama_check_interval_slider = gr.Slider(
+ label=i18n("Save model every n steps"),
+ info=i18n(
+ "make sure that it's not greater than max_steps"
+ ),
+ interactive=True,
+ minimum=1,
+ maximum=1000,
+ step=1,
+ value=50,
+ )
+ with gr.Row(equal_height=False):
+ llama_grad_batches = gr.Slider(
+ label=i18n("Accumulate Gradient Batches"),
+ interactive=True,
+ minimum=1,
+ maximum=20,
+ step=1,
+ value=init_llama_yml["trainer"][
+ "accumulate_grad_batches"
+ ],
+ )
+ llama_use_speaker = gr.Slider(
+ label=i18n(
+ "Probability of applying Speaker Condition"
+ ),
+ interactive=True,
+ minimum=0.1,
+ maximum=1.0,
+ step=0.05,
+ value=init_llama_yml["train_dataset"][
+ "interactive_prob"
+ ],
+ )
+
+ with gr.Tab(label=i18n("Merge LoRA"), id=4):
+ with gr.Row(equal_height=False):
+ llama_weight = gr.Dropdown(
+ label=i18n("Base LLAMA Model"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ "checkpoints/fish-speech-1.4/model.pth",
+ ],
+ value="checkpoints/fish-speech-1.4/model.pth",
+ allow_custom_value=True,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ lora_weight = gr.Dropdown(
+ label=i18n("LoRA Model to be merged"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ str(p)
+ for p in Path("results").glob("lora*/**/*.ckpt")
+ ],
+ allow_custom_value=True,
+ interactive=True,
+ )
+ lora_llama_config = gr.Dropdown(
+ label=i18n("LLAMA Model Config"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ "text2semantic_finetune",
+ ],
+ value="text2semantic_finetune",
+ allow_custom_value=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lora_output = gr.Dropdown(
+ label=i18n("Output Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ value="checkpoints/merged",
+ choices=["checkpoints/merged"],
+ allow_custom_value=True,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lora_merge_btn = gr.Button(
+ value=i18n("Merge"), variant="primary"
+ )
+
+ with gr.Tab(label=i18n("Model Quantization"), id=5):
+ with gr.Row(equal_height=False):
+ llama_weight_to_quantify = gr.Dropdown(
+ label=i18n("Base LLAMA Model"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=list_llama_models(),
+ value="checkpoints/fish-speech-1.4",
+ allow_custom_value=True,
+ interactive=True,
+ )
+ quantify_mode = gr.Dropdown(
+ label=i18n("Post-quantification Precision"),
+ info=i18n(
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
+ ),
+ choices=["int8", "int4"],
+ value="int8",
+ allow_custom_value=False,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_quantify_btn = gr.Button(
+ value=i18n("Quantify"), variant="primary"
+ )
+
+ with gr.Tab(label="Tensorboard", id=6):
+ with gr.Row(equal_height=False):
+ tb_host = gr.Textbox(
+ label=i18n("Tensorboard Host"), value="127.0.0.1"
+ )
+ tb_port = gr.Textbox(
+ label=i18n("Tensorboard Port"), value="11451"
+ )
+ with gr.Row(equal_height=False):
+ tb_dir = gr.Dropdown(
+ label=i18n("Tensorboard Log Path"),
+ allow_custom_value=True,
+ choices=[
+ str(p)
+ for p in Path("results").glob("**/tensorboard/")
+ ],
+ )
+ with gr.Row(equal_height=False):
+ if_tb = gr.Checkbox(
+ label=i18n("Open Tensorboard"),
+ )
+
+ with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
+ with gr.Column():
+ with gr.Row():
+ with gr.Accordion(
+ label="\U0001F5A5 "
+ + i18n("Inference Server Configuration"),
+ open=False,
+ ):
+ with gr.Row():
+ infer_host_textbox = gr.Textbox(
+ label=i18n("WebUI Host"), value="127.0.0.1"
+ )
+ infer_port_textbox = gr.Textbox(
+ label=i18n("WebUI Port"), value="7862"
+ )
+ with gr.Row():
+ infer_decoder_model = gr.Dropdown(
+ label=i18n("Decoder Model Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=list_decoder_models(),
+ value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ allow_custom_value=True,
+ )
+ infer_decoder_config = gr.Dropdown(
+ label=i18n("Decoder Model Config"),
+ info=i18n("Changing with the Model Path"),
+ value="firefly_gan_vq",
+ choices=[
+ "firefly_gan_vq",
+ ],
+ allow_custom_value=True,
+ )
+ with gr.Row():
+ infer_llama_model = gr.Dropdown(
+ label=i18n("LLAMA Model Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ value="checkpoints/fish-speech-1.4",
+ choices=list_llama_models(),
+ allow_custom_value=True,
+ )
+
+ with gr.Row():
+ infer_compile = gr.Radio(
+ label=i18n("Compile Model"),
+ info=i18n(
+ "Compile the model can significantly reduce the inference time, but will increase cold start time"
+ ),
+ choices=["Yes", "No"],
+ value=(
+ "Yes" if (sys.platform == "linux") else "No"
+ ),
+ interactive=is_module_installed("triton"),
+ )
+
+ with gr.Row():
+ infer_checkbox = gr.Checkbox(
+ label=i18n("Open Inference Server")
+ )
+ infer_error = gr.HTML(label=i18n("Inference Server Error"))
+
+ with gr.Column():
+ train_error = gr.HTML(label=i18n("Training Error"))
+ checkbox_group = gr.CheckboxGroup(
+ label="\U0001F4CA " + i18n("Data Source"),
+ info=i18n(
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
+ ),
+ elem_classes=["data_src"],
+ )
+ train_box = gr.Textbox(
+ label=i18n("Data Preprocessing Path"),
+ value=str(data_pre_output),
+ interactive=False,
+ )
+ model_box = gr.Textbox(
+ label="\U0001F4BE " + i18n("Model Output Path"),
+ value=str(default_model_output),
+ interactive=False,
+ )
+
+ with gr.Accordion(
+ i18n(
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
+ ),
+ elem_classes=["scrollable-component"],
+ elem_id="file_accordion",
+ ):
+ tree_slider = gr.Slider(
+ minimum=0,
+ maximum=3,
+ value=0,
+ step=1,
+ show_label=False,
+ container=False,
+ )
+ file_markdown = new_explorer(str(data_pre_output), 0)
+ with gr.Row(equal_height=False):
+ admit_btn = gr.Button(
+ "\U00002705 " + i18n("File Preprocessing"),
+ variant="primary",
+ )
+ fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
+ help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
+ train_btn = gr.Button(i18n("Start Training"), variant="primary")
+
+ footer = load_data_in_raw("fish_speech/webui/html/footer.html")
+ footer = footer.format(
+ versions=versions_html(),
+ api_docs="https://speech.fish.audio/inference/#http-api",
+ )
+ gr.HTML(footer, elem_id="footer")
+ vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
+ llama_page.select(lambda: "LLAMA", None, model_type_radio)
+ add_button.click(
+ fn=add_item,
+ inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
+ outputs=[checkbox_group, error],
+ )
+ remove_button.click(
+ fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
+ )
+ checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
+ help_button.click(
+ fn=None,
+ js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
+ 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
+ )
+ if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
+ if_initial_prompt.change(
+ fn=lambda x: gr.Textbox(value="", interactive=x),
+ inputs=[if_initial_prompt],
+ outputs=[initial_prompt],
+ )
+ train_btn.click(
+ fn=train_process,
+ inputs=[
+ train_box,
+ model_type_radio,
+ # llama config
+ llama_ckpt,
+ llama_base_config,
+ llama_lr_slider,
+ llama_maxsteps_slider,
+ llama_data_num_workers_slider,
+ llama_data_batch_size_slider,
+ llama_data_max_length_slider,
+ llama_precision_dropdown,
+ llama_check_interval_slider,
+ llama_grad_batches,
+ llama_use_speaker,
+ llama_use_lora,
+ ],
+ outputs=[train_error],
+ )
+ if_tb.change(
+ fn=tensorboard_process,
+ inputs=[if_tb, tb_dir, tb_host, tb_port],
+ outputs=[train_error],
+ )
+ tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
+ infer_decoder_model.change(
+ fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
+ )
+ infer_llama_model.change(
+ fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
+ )
+ llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
+ admit_btn.click(
+ fn=check_files,
+ inputs=[train_box, tree_slider, label_model, label_device],
+ outputs=[error, file_markdown],
+ )
+ fresh_btn.click(
+ fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
+ )
+ llama_use_lora.change(
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+ )
+ llama_ckpt.change(
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+ )
+ lora_weight.change(
+ fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
+ inputs=[],
+ outputs=[lora_weight],
+ )
+ llama_lora_merge_btn.click(
+ fn=llama_lora_merge,
+ inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
+ outputs=[train_error],
+ )
+ llama_quantify_btn.click(
+ fn=llama_quantify,
+ inputs=[llama_weight_to_quantify, quantify_mode],
+ outputs=[train_error],
+ )
+ infer_checkbox.change(
+ fn=change_infer,
+ inputs=[
+ infer_checkbox,
+ infer_host_textbox,
+ infer_port_textbox,
+ infer_decoder_model,
+ infer_decoder_config,
+ infer_llama_model,
+ infer_compile,
+ ],
+ outputs=[infer_error],
+ )
+
+demo.launch(inbrowser=True)
diff --git a/inference.ipynb b/inference.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3bd94ebe74260476febc4f90c0c160fefdfe5882
--- /dev/null
+++ b/inference.ipynb
@@ -0,0 +1,214 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fish Speech"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### For Windows User / win用户"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "bat"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!chcp 65001"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### For Linux User / Linux 用户"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import locale\n",
+ "locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prepare Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For Chinese users, you probably want to use mirror to accelerate downloading\n",
+ "# !set HF_ENDPOINT=https://hf-mirror.com\n",
+ "# !export HF_ENDPOINT=https://hf-mirror.com \n",
+ "\n",
+ "!huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## WebUI Inference\n",
+ "\n",
+ "> You can use --compile to fuse CUDA kernels for faster inference (10x)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/run_webui.py \\\n",
+ " --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
+ " --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
+ " # --compile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Break-down CLI Inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. Encode reference audio: / 从语音生成 prompt: \n",
+ "\n",
+ "You should get a `fake.npy` file.\n",
+ "\n",
+ "你应该能得到一个 `fake.npy` 文件."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "## Enter the path to the audio file here\n",
+ "src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
+ "\n",
+ "!python tools/vqgan/inference.py \\\n",
+ " -i {src_audio} \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+ "\n",
+ "from IPython.display import Audio, display\n",
+ "audio = Audio(filename=\"fake.wav\")\n",
+ "display(audio)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. Generate semantic tokens from text: / 从文本生成语义 token:\n",
+ "\n",
+ "> This command will create a codes_N file in the working directory, where N is an integer starting from 0.\n",
+ "\n",
+ "> You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~300 tokens/second).\n",
+ "\n",
+ "> 该命令会在工作目录下创建 codes_N 文件, 其中 N 是从 0 开始的整数.\n",
+ "\n",
+ "> 您可以使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 tokens/秒 -> ~300 tokens/秒)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/llama/generate.py \\\n",
+ " --text \"hello world\" \\\n",
+ " --prompt-text \"The text corresponding to reference audio\" \\\n",
+ " --prompt-tokens \"fake.npy\" \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.4\" \\\n",
+ " --num-samples 2\n",
+ " # --compile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/vqgan/inference.py \\\n",
+ " -i \"codes_0.npy\" \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+ "\n",
+ "from IPython.display import Audio, display\n",
+ "audio = Audio(filename=\"fake.wav\")\n",
+ "display(audio)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/install_env.bat b/install_env.bat
new file mode 100644
index 0000000000000000000000000000000000000000..744ddb42c70a84d342fc3a2e357d8d36f7d4c998
--- /dev/null
+++ b/install_env.bat
@@ -0,0 +1,180 @@
+@echo off
+chcp 65001
+
+set USE_MIRROR=true
+echo "USE_MIRROR: %USE_MIRROR%"
+setlocal enabledelayedexpansion
+
+cd /D "%~dp0"
+
+set PATH="%PATH%";%SystemRoot%\system32
+
+echo %PATH%
+
+
+echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
+ echo.
+ echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
+ goto end
+ )
+)
+
+
+set TMP=%CD%\fishenv
+set TEMP=%CD%\fishenv
+
+(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul
+
+set INSTALL_DIR=%cd%\fishenv
+set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
+set INSTALL_ENV_DIR=%cd%\fishenv\env
+set PIP_CMD=%cd%\fishenv\env\python -m pip
+set PYTHON_CMD=%cd%\fishenv\env\python
+set API_FLAG_PATH=%~dp0API_FLAGS.txt
+set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
+if "!USE_MIRROR!" == "true" (
+ set MINICONDA_DOWNLOAD_URL=https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
+)
+set MINICONDA_CHECKSUM=307194e1f12bbeb52b083634e89cc67db4f7980bd542254b43d3309eaf7cb358
+set conda_exists=F
+
+call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1
+if "%ERRORLEVEL%" EQU "0" set conda_exists=T
+
+if "%conda_exists%" == "F" (
+ echo.
+ echo Downloading Miniconda...
+ mkdir "%INSTALL_DIR%" 2>nul
+ call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe"
+ if errorlevel 1 (
+ echo.
+ echo Failed to download miniconda.
+ goto end
+ )
+ for /f %%a in ('
+ certutil -hashfile "%INSTALL_DIR%\miniconda_installer.exe" sha256
+ ^| find /i /v " "
+ ^| find /i "%MINICONDA_CHECKSUM%"
+ ') do (
+ set "hash=%%a"
+ )
+ if not defined hash (
+ echo.
+ echo Miniconda hash mismatched!
+ del "%INSTALL_DIR%\miniconda_installer.exe"
+ goto end
+ ) else (
+ echo.
+ echo Miniconda hash matched successfully.
+ )
+ echo Downloaded "%CONDA_ROOT_PREFIX%"
+ start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX%
+
+ call "%CONDA_ROOT_PREFIX%\_conda.exe" --version
+ if errorlevel 1 (
+ echo.
+ echo Cannot install Miniconda.
+ goto end
+ ) else (
+ echo.
+ echo Miniconda Install success.
+ )
+
+ del "%INSTALL_DIR%\miniconda_installer.exe"
+)
+
+
+if not exist "%INSTALL_ENV_DIR%" (
+ echo.
+ echo Creating Conda Environment...
+ if "!USE_MIRROR!" == "true" (
+ call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ python=3.10
+ ) else (
+ call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10
+ )
+
+ if errorlevel 1 (
+ echo.
+ echo Failed to Create Environment.
+ goto end
+ )
+)
+
+if not exist "%INSTALL_ENV_DIR%\python.exe" (
+ echo.
+ echo Conda Env does not exist.
+ goto end
+)
+
+set PYTHONNOUSERSITE=1
+set PYTHONPATH=
+set PYTHONHOME=
+set "CUDA_PATH=%INSTALL_ENV_DIR%"
+set "CUDA_HOME=%CUDA_PATH%"
+
+call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"
+
+if errorlevel 1 (
+ echo.
+ echo Failed to activate Env.
+ goto end
+) else (
+ echo.
+ echo successfully create env.
+)
+
+set "HF_ENDPOINT=https://huggingface.co"
+set "no_proxy="
+if "%USE_MIRROR%"=="true" (
+ set "HF_ENDPOINT=https://hf-mirror.com"
+ set "no_proxy=localhost,127.0.0.1,0.0.0.0"
+)
+
+echo "HF_ENDPOINT: !HF_ENDPOINT!"
+echo "NO_PROXY: !no_proxy!"
+
+%PIP_CMD% install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+%PIP_CMD% install -e . --upgrade-strategy only-if-needed
+
+call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^
+ "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true" ^
+ "2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a"
+
+
+endlocal
+echo "Environment Check: Success."
+:end
+pause
+
+goto :EOF
+
+
+:download_and_install
+setlocal
+
+set "WHEEL_FILE=%1"
+set "URL=%2"
+set "CHKSUM=%3"
+
+:DOWNLOAD
+if not exist "%WHEEL_FILE%" (
+ call curl -Lk "%URL%" --output "%WHEEL_FILE%"
+)
+
+for /f "delims=" %%I in ("certutil -hashfile %WHEEL_FILE% SHA256 ^| find /i %CHKSUM%") do (
+ set "FILE_VALID=true"
+)
+
+if not defined FILE_VALID (
+ echo File checksum does not match, re-downloading...
+ del "%WHEEL_FILE%"
+ goto DOWNLOAD
+)
+
+echo "OK for %WHEEL_FILE%"
+%PIP_CMD% install "%WHEEL_FILE%" --no-warn-script-location
+del "%WHEEL_FILE%"
+
+endlocal
+goto :EOF
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ae2099fe85ba34e2848ad6a1d1bea498155908b6
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,144 @@
+site_name: Fish Speech
+site_description: Targeting SOTA TTS solutions.
+site_url: https://speech.fish.audio
+
+# Repository
+repo_name: fishaudio/fish-speech
+repo_url: https://github.com/fishaudio/fish-speech
+edit_uri: blob/main/docs
+
+# Copyright
+copyright: Copyright © 2023-2024 by Fish Audio
+
+theme:
+ name: material
+ favicon: assets/figs/logo-circle.png
+ language: en
+ features:
+ - content.action.edit
+ - content.action.view
+ - navigation.tracking
+ - navigation.footer
+ # - navigation.tabs
+ - search
+ - search.suggest
+ - search.highlight
+ - search.share
+ - content.code.copy
+ icon:
+ logo: fontawesome/solid/fish
+
+ palette:
+ # Palette toggle for automatic mode
+ - media: "(prefers-color-scheme)"
+ toggle:
+ icon: material/brightness-auto
+ name: Switch to light mode
+
+ # Palette toggle for light mode
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+ primary: black
+ font:
+ code: Roboto Mono
+
+ # Palette toggle for dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ toggle:
+ icon: material/brightness-4
+ name: Switch to light mode
+ primary: black
+ font:
+ code: Roboto Mono
+
+nav:
+ - Introduction: index.md
+ - Finetune: finetune.md
+ - Inference: inference.md
+ - Start Agent: start_agent.md
+ - Samples: samples.md
+
+# Plugins
+plugins:
+ - search:
+ separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
+ lang:
+ - en
+ - zh
+ - ja
+ - pt
+ - ko
+ - i18n:
+ docs_structure: folder
+ languages:
+ - locale: en
+ name: English
+ default: true
+ build: true
+ - locale: zh
+ name: 简体中文
+ build: true
+ nav:
+ - 介绍: zh/index.md
+ - 微调: zh/finetune.md
+ - 推理: zh/inference.md
+ - 启动Agent: zh/start_agent.md
+ - 例子: zh/samples.md
+ - locale: ja
+ name: 日本語
+ build: true
+ nav:
+ - Fish Speech の紹介: ja/index.md
+ - 微調整: ja/finetune.md
+ - 推論: ja/inference.md
+ - スタートエージェント: ja/start_agent.md
+ - サンプル: ja/samples.md
+ - locale: pt
+ name: Português (Brasil)
+ build: true
+ nav:
+ - Introdução: pt/index.md
+ - Ajuste Fino: pt/finetune.md
+ - Inferência: pt/inference.md
+ - Agente inicial: pt/start_agent.md
+ - Amostras: pt/samples.md
+ - locale: ko
+ name: 한국어
+ build: true
+ nav:
+ - 소개: ko/index.md
+ - 파인튜닝: ko/finetune.md
+ - 추론: ko/inference.md
+ - 샘플: ko/samples.md
+
+markdown_extensions:
+ - pymdownx.highlight:
+ anchor_linenums: true
+ line_spans: __span
+ pygments_lang_class: true
+ - pymdownx.inlinehilite
+ - pymdownx.snippets
+ - pymdownx.superfences
+ - admonition
+ - pymdownx.details
+ - pymdownx.superfences
+ - attr_list
+ - md_in_html
+ - pymdownx.superfences
+
+extra_css:
+ - stylesheets/extra.css
+
+extra:
+ social:
+ - icon: fontawesome/brands/discord
+ link: https://discord.gg/Es5qTB9BcN
+ - icon: fontawesome/brands/docker
+ link: https://hub.docker.com/r/fishaudio/fish-speech
+ - icon: fontawesome/brands/qq
+ link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093
+ homepage: https://speech.fish.audio
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..d5acebb1fe5d2fed5842fc3b6b7a5885e04dabef
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,62 @@
+[project]
+name = "fish-speech"
+version = "0.1.0"
+authors = [
+ {name = "Lengyue", email = "lengyue@lengyue.me"},
+]
+description = "Fish Speech"
+readme = "README.md"
+requires-python = ">=3.10"
+keywords = ["TTS", "Speech"]
+license = {text = "CC BY-NC-SA 4.0"}
+classifiers = [
+ "Programming Language :: Python :: 3",
+]
+dependencies = [
+ "numpy<=1.26.4",
+ "transformers>=4.45.2",
+ "datasets==2.18.0",
+ "lightning>=2.1.0",
+ "hydra-core>=1.3.2",
+ "tensorboard>=2.14.1",
+ "natsort>=8.4.0",
+ "einops>=0.7.0",
+ "librosa>=0.10.1",
+ "rich>=13.5.3",
+ "gradio>5.0.0",
+ "wandb>=0.15.11",
+ "grpcio>=1.58.0",
+ "kui>=1.6.0",
+ "uvicorn>=0.30.0",
+ "loguru>=0.6.0",
+ "loralib>=0.1.2",
+ "pyrootutils>=1.0.4",
+ "vector_quantize_pytorch==1.14.24",
+ "resampy>=0.4.3",
+ "einx[torch]==0.2.2",
+ "zstandard>=0.22.0",
+ "pydub",
+ "pyaudio",
+ "faster_whisper",
+ "modelscope==1.17.1",
+ "funasr==1.1.5",
+ "opencc-python-reimplemented==0.1.7",
+ "silero-vad",
+ "ormsgpack",
+ "tiktoken>=0.8.0",
+ "pydantic==2.9.2",
+ "cachetools",
+]
+
+[project.optional-dependencies]
+stable = [
+ "torch<=2.4.1",
+ "torchaudio",
+]
+
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["fish_speech", "tools"]
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad1493530f7f6d8fa476dbe0b76e6239fce2d7e7
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,6 @@
+{
+ "exclude": [
+ "data",
+ "filelists"
+ ]
+}
diff --git a/run_cmd.bat b/run_cmd.bat
new file mode 100644
index 0000000000000000000000000000000000000000..c2af8a9b6fb75df7b7c81ff5986286845e247fb9
--- /dev/null
+++ b/run_cmd.bat
@@ -0,0 +1,50 @@
+@echo off
+chcp 65001
+
+set no_proxy="127.0.0.1, 0.0.0.0, localhost"
+setlocal enabledelayedexpansion
+
+cd /D "%~dp0"
+
+set PATH="%PATH%";%SystemRoot%\system32
+
+
+echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
+ echo.
+ echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
+ goto end
+ )
+)
+
+
+set TMP=%CD%\fishenv
+set TEMP=%CD%\fishenv
+
+
+(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul
+
+
+set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
+set INSTALL_ENV_DIR=%cd%\fishenv\env
+
+
+set PYTHONNOUSERSITE=1
+set PYTHONPATH=%~dp0
+set PYTHONHOME=
+
+
+call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"
+
+if errorlevel 1 (
+ echo.
+ echo Environment activation failed.
+ goto end
+) else (
+ echo.
+ echo Environment activation succeeded.
+)
+
+cmd /k "%*"
+
+:end
+pause
diff --git a/start.bat b/start.bat
new file mode 100644
index 0000000000000000000000000000000000000000..c4e27014f58b30fd0ff8d7e149b4cf431f526f89
--- /dev/null
+++ b/start.bat
@@ -0,0 +1,97 @@
+@echo off
+chcp 65001
+
+set USE_MIRROR=true
+set PYTHONPATH=%~dp0
+set PYTHON_CMD=python
+if exist "fishenv" (
+ set PYTHON_CMD=%cd%\fishenv\env\python
+)
+
+set API_FLAG_PATH=%~dp0API_FLAGS.txt
+set KMP_DUPLICATE_LIB_OK=TRUE
+
+setlocal enabledelayedexpansion
+
+set "HF_ENDPOINT=https://huggingface.co"
+set "no_proxy="
+if "%USE_MIRROR%" == "true" (
+ set "HF_ENDPOINT=https://hf-mirror.com"
+ set "no_proxy=localhost, 127.0.0.1, 0.0.0.0"
+)
+echo "HF_ENDPOINT: !HF_ENDPOINT!"
+echo "NO_PROXY: !no_proxy!"
+
+echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
+ echo.
+ echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
+ goto end
+ )
+)
+
+%PYTHON_CMD% .\tools\download_models.py
+
+set "API_FLAGS="
+set "flags="
+
+if exist "%API_FLAG_PATH%" (
+ for /f "usebackq tokens=*" %%a in ("%API_FLAG_PATH%") do (
+ set "line=%%a"
+ if not "!line:~0,1!"=="#" (
+ set "line=!line: =!"
+ set "line=!line:\=!"
+ set "line=!line:= !"
+ if not "!line!"=="" (
+ set "API_FLAGS=!API_FLAGS!!line! "
+ )
+ )
+ )
+)
+
+
+if not "!API_FLAGS!"=="" set "API_FLAGS=!API_FLAGS:~0,-1!"
+
+set "flags="
+
+echo !API_FLAGS! | findstr /C:"--api" >nul 2>&1
+if !errorlevel! equ 0 (
+ echo.
+ echo Start HTTP API...
+ set "mode=api"
+ goto process_flags
+)
+
+echo !API_FLAGS! | findstr /C:"--infer" >nul 2>&1
+if !errorlevel! equ 0 (
+ echo.
+ echo Start WebUI Inference...
+ set "mode=infer"
+ goto process_flags
+)
+
+
+:process_flags
+for %%p in (!API_FLAGS!) do (
+ if not "%%p"=="--!mode!" (
+ set "flags=!flags! %%p"
+ )
+)
+
+if not "!flags!"=="" set "flags=!flags:~1!"
+
+echo Debug: flags = !flags!
+
+if "!mode!"=="api" (
+ %PYTHON_CMD% -m tools.api_server !flags!
+) else if "!mode!"=="infer" (
+ %PYTHON_CMD% -m tools.webui !flags!
+)
+
+echo.
+echo Next launch the page...
+%PYTHON_CMD% fish_speech\webui\manage.py
+
+
+:end
+endlocal
+pause
diff --git a/start_webui.sh b/start_webui.sh
new file mode 100644
index 0000000000000000000000000000000000000000..97020c4953e84b437834be714f2e784b98254685
--- /dev/null
+++ b/start_webui.sh
@@ -0,0 +1 @@
+python tools/run_webui.py --compile
diff --git a/tools/__pycache__/api_server.cpython-310.pyc b/tools/__pycache__/api_server.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d24581bb7ed3626023682abd1c338801ffd5033
Binary files /dev/null and b/tools/__pycache__/api_server.cpython-310.pyc differ
diff --git a/tools/__pycache__/e2e_webui.cpython-310.pyc b/tools/__pycache__/e2e_webui.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c727777b3c1e9c552b5bb876f106dbadc8ab4a
Binary files /dev/null and b/tools/__pycache__/e2e_webui.cpython-310.pyc differ
diff --git a/tools/__pycache__/file.cpython-310.pyc b/tools/__pycache__/file.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2521d95f10c62393062a5b8465628ccf72a07aac
Binary files /dev/null and b/tools/__pycache__/file.cpython-310.pyc differ
diff --git a/tools/__pycache__/fish_e2e.cpython-310.pyc b/tools/__pycache__/fish_e2e.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..524e83d2d8060d32eac06d0db0a2935d37ca5e4a
Binary files /dev/null and b/tools/__pycache__/fish_e2e.cpython-310.pyc differ
diff --git a/tools/__pycache__/schema.cpython-310.pyc b/tools/__pycache__/schema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61af3bdb31df672f9336b585693b79495145bac9
Binary files /dev/null and b/tools/__pycache__/schema.cpython-310.pyc differ
diff --git a/tools/api_client.py b/tools/api_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..b47a10d5fbc790cf73bb669766fbbba092f0f8c3
--- /dev/null
+++ b/tools/api_client.py
@@ -0,0 +1,221 @@
+import argparse
+import base64
+import wave
+
+import ormsgpack
+import pyaudio
+import requests
+from pydub import AudioSegment
+from pydub.playback import play
+
+from tools.file import audio_to_bytes, read_ref_text
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
+
+
+def parse_args():
+
+ parser = argparse.ArgumentParser(
+ description="Send a WAV file and text to a server and receive synthesized audio.",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--url",
+ "-u",
+ type=str,
+ default="http://127.0.0.1:8080/v1/tts",
+ help="URL of the server",
+ )
+ parser.add_argument(
+ "--text", "-t", type=str, required=True, help="Text to be synthesized"
+ )
+ parser.add_argument(
+ "--reference_id",
+ "-id",
+ type=str,
+ default=None,
+ help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
+ )
+ parser.add_argument(
+ "--reference_audio",
+ "-ra",
+ type=str,
+ nargs="+",
+ default=None,
+ help="Path to the audio file",
+ )
+ parser.add_argument(
+ "--reference_text",
+ "-rt",
+ type=str,
+ nargs="+",
+ default=None,
+ help="Reference text for voice synthesis",
+ )
+ parser.add_argument(
+ "--output",
+ "-o",
+ type=str,
+ default="generated_audio",
+ help="Output audio file name",
+ )
+ parser.add_argument(
+ "--play",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Whether to play audio after receiving data",
+ )
+ parser.add_argument("--normalize", type=bool, default=True)
+ parser.add_argument(
+ "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
+ )
+ parser.add_argument(
+ "--latency",
+ type=str,
+ default="normal",
+ choices=["normal", "balanced"],
+ help="Used in api.fish.audio/v1/tts",
+ )
+ parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ default=1024,
+ help="Maximum new tokens to generate. \n0 means no limit.",
+ )
+ parser.add_argument(
+ "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
+ )
+ parser.add_argument(
+ "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
+ )
+ parser.add_argument(
+ "--repetition_penalty",
+ type=float,
+ default=1.2,
+ help="Repetition penalty for synthesis",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0.7, help="Temperature for sampling"
+ )
+
+ parser.add_argument(
+ "--streaming", type=bool, default=False, help="Enable streaming response"
+ )
+ parser.add_argument(
+ "--channels", type=int, default=1, help="Number of audio channels"
+ )
+ parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
+ parser.add_argument(
+ "--use_memory_cache",
+ type=str,
+ default="off",
+ choices=["on", "off"],
+ help="Cache encoded references codes in memory.\n",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="`None` means randomized inference, otherwise deterministic.\n"
+ "It can't be used for fixing a timbre.",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+
+ args = parse_args()
+
+ idstr: str | None = args.reference_id
+ # priority: ref_id > [{text, audio},...]
+ if idstr is None:
+ ref_audios = args.reference_audio
+ ref_texts = args.reference_text
+ if ref_audios is None:
+ byte_audios = []
+ else:
+ byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
+ if ref_texts is None:
+ ref_texts = []
+ else:
+ ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
+ else:
+ byte_audios = []
+ ref_texts = []
+ pass # in api.py
+
+ data = {
+ "text": args.text,
+ "references": [
+ ServeReferenceAudio(
+ audio=ref_audio if ref_audio is not None else b"", text=ref_text
+ )
+ for ref_text, ref_audio in zip(ref_texts, byte_audios)
+ ],
+ "reference_id": idstr,
+ "normalize": args.normalize,
+ "format": args.format,
+ "max_new_tokens": args.max_new_tokens,
+ "chunk_length": args.chunk_length,
+ "top_p": args.top_p,
+ "repetition_penalty": args.repetition_penalty,
+ "temperature": args.temperature,
+ "streaming": args.streaming,
+ "use_memory_cache": args.use_memory_cache,
+ "seed": args.seed,
+ }
+
+ pydantic_data = ServeTTSRequest(**data)
+
+ response = requests.post(
+ args.url,
+ data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+ stream=args.streaming,
+ headers={
+ "authorization": "Bearer YOUR_API_KEY",
+ "content-type": "application/msgpack",
+ },
+ )
+
+ if response.status_code == 200:
+ if args.streaming:
+ p = pyaudio.PyAudio()
+ audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
+ stream = p.open(
+ format=audio_format, channels=args.channels, rate=args.rate, output=True
+ )
+
+ wf = wave.open(f"{args.output}.wav", "wb")
+ wf.setnchannels(args.channels)
+ wf.setsampwidth(p.get_sample_size(audio_format))
+ wf.setframerate(args.rate)
+
+ stream_stopped_flag = False
+
+ try:
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk:
+ stream.write(chunk)
+ wf.writeframesraw(chunk)
+ else:
+ if not stream_stopped_flag:
+ stream.stop_stream()
+ stream_stopped_flag = True
+ finally:
+ stream.close()
+ p.terminate()
+ wf.close()
+ else:
+ audio_content = response.content
+ audio_path = f"{args.output}.{args.format}"
+ with open(audio_path, "wb") as audio_file:
+ audio_file.write(audio_content)
+
+ audio = AudioSegment.from_file(audio_path, format=args.format)
+ if args.play:
+ play(audio)
+ print(f"Audio has been saved to '{audio_path}'.")
+ else:
+ print(f"Request failed with status code {response.status_code}")
+ print(response.json())
diff --git a/tools/api_server.py b/tools/api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5d26fccc30a7d06ef8a263247f724c788f9a5c
--- /dev/null
+++ b/tools/api_server.py
@@ -0,0 +1,98 @@
+from threading import Lock
+
+import pyrootutils
+import uvicorn
+from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.server.api_utils import MsgPackRequest, parse_args
+from tools.server.exception_handler import ExceptionHandler
+from tools.server.model_manager import ModelManager
+from tools.server.views import (
+ ASRView,
+ ChatView,
+ HealthView,
+ TTSView,
+ VQGANDecodeView,
+ VQGANEncodeView,
+)
+
+
+class API(ExceptionHandler):
+ def __init__(self):
+ self.args = parse_args()
+ self.routes = [
+ ("/v1/health", HealthView),
+ ("/v1/vqgan/encode", VQGANEncodeView),
+ ("/v1/vqgan/decode", VQGANDecodeView),
+ ("/v1/asr", ASRView),
+ ("/v1/tts", TTSView),
+ ("/v1/chat", ChatView),
+ ]
+ self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
+
+ self.openapi = OpenAPI(
+ {
+ "title": "Fish Speech API",
+ "version": "1.5.0",
+ },
+ ).routes
+
+ # Initialize the app
+ self.app = Kui(
+ routes=self.routes + self.openapi[1:], # Remove the default route
+ exception_handlers={
+ HTTPException: self.http_exception_handler,
+ Exception: self.other_exception_handler,
+ },
+ factory_class=FactoryClass(http=MsgPackRequest),
+ cors_config={},
+ )
+
+ # Add the state variables
+ self.app.state.lock = Lock()
+ self.app.state.device = self.args.device
+ self.app.state.max_text_length = self.args.max_text_length
+
+ # Associate the app with the model manager
+ self.app.on_startup(self.initialize_app)
+
+ async def initialize_app(self, app: Kui):
+ # Make the ModelManager available to the views
+ app.state.model_manager = ModelManager(
+ mode=self.args.mode,
+ device=self.args.device,
+ half=self.args.half,
+ compile=self.args.compile,
+ asr_enabled=self.args.load_asr_model,
+ llama_checkpoint_path=self.args.llama_checkpoint_path,
+ decoder_checkpoint_path=self.args.decoder_checkpoint_path,
+ decoder_config_name=self.args.decoder_config_name,
+ )
+
+ logger.info(f"Startup done, listening server at http://{self.args.listen}")
+
+
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.
+
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+
+if __name__ == "__main__":
+
+ api = API()
+ host, port = api.args.listen.split(":")
+
+ uvicorn.run(
+ api.app,
+ host=host,
+ port=int(port),
+ workers=api.args.workers,
+ log_level="info",
+ )
diff --git a/tools/download_models.py b/tools/download_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14a0991698c13635eceaffbbd7d2a47ff05138d
--- /dev/null
+++ b/tools/download_models.py
@@ -0,0 +1,55 @@
+import os
+
+from huggingface_hub import hf_hub_download
+
+
+# Download
+def check_and_download_files(repo_id, file_list, local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+ for file in file_list:
+ file_path = os.path.join(local_dir, file)
+ if not os.path.exists(file_path):
+ print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=file,
+ resume_download=True,
+ local_dir=local_dir,
+ local_dir_use_symlinks=False,
+ )
+ else:
+ print(f"{file} 已存在,跳过下载。")
+
+
+# 1st
+repo_id_1 = "fishaudio/fish-speech-1.5"
+local_dir_1 = "./checkpoints/fish-speech-1.5"
+files_1 = [
+ "gitattributes",
+ "model.pth",
+ "README.md",
+ "special_tokens.json",
+ "tokenizer.tiktoken",
+ "config.json",
+ "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+]
+
+# 3rd
+repo_id_3 = "fishaudio/fish-speech-1"
+local_dir_3 = "./"
+files_3 = [
+ "ffmpeg.exe",
+ "ffprobe.exe",
+]
+
+# 4th
+repo_id_4 = "SpicyqSama007/fish-speech-packed"
+local_dir_4 = "./"
+files_4 = [
+ "asr-label-win-x64.exe",
+]
+
+check_and_download_files(repo_id_1, files_1, local_dir_1)
+
+check_and_download_files(repo_id_3, files_3, local_dir_3)
+check_and_download_files(repo_id_4, files_4, local_dir_4)
diff --git a/tools/e2e_webui.py b/tools/e2e_webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..37474fbd5645c09fcbe6caac1331672614c5c821
--- /dev/null
+++ b/tools/e2e_webui.py
@@ -0,0 +1,232 @@
+import io
+import re
+import wave
+
+import gradio as gr
+import numpy as np
+
+from .fish_e2e import FishE2EAgent, FishE2EEventType
+from .schema import ServeMessage, ServeTextPart, ServeVQPart
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+class ChatState:
+ def __init__(self):
+ self.conversation = []
+ self.added_systext = False
+ self.added_sysaudio = False
+
+ def get_history(self):
+ results = []
+ for msg in self.conversation:
+ results.append({"role": msg.role, "content": self.repr_message(msg)})
+
+ # Process assistant messages to extract questions and update user messages
+ for i, msg in enumerate(results):
+ if msg["role"] == "assistant":
+ match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
+ if match and i > 0 and results[i - 1]["role"] == "user":
+ # Update previous user message with extracted question
+ results[i - 1]["content"] += "\n" + match.group(1)
+ # Remove the Question/Answer format from assistant message
+ msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
+ return results
+
+ def repr_message(self, msg: ServeMessage):
+ response = ""
+ for part in msg.parts:
+ if isinstance(part, ServeTextPart):
+ response += part.text
+ elif isinstance(part, ServeVQPart):
+ response += f""
+ return response
+
+
+def clear_fn():
+ return [], ChatState(), None, None, None
+
+
+async def process_audio_input(
+ sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
+):
+ if audio_input is None and not text_input:
+ raise gr.Error("No input provided")
+
+ agent = FishE2EAgent() # Create new agent instance for each request
+
+ # Convert audio input to numpy array
+ if isinstance(audio_input, tuple):
+ sr, audio_data = audio_input
+ elif text_input:
+ sr = 44100
+ audio_data = None
+ else:
+ raise gr.Error("Invalid audio format")
+
+ if isinstance(sys_audio_input, tuple):
+ sr, sys_audio_data = sys_audio_input
+ else:
+ sr = 44100
+ sys_audio_data = None
+
+ def append_to_chat_ctx(
+ part: ServeTextPart | ServeVQPart, role: str = "assistant"
+ ) -> None:
+ if not state.conversation or state.conversation[-1].role != role:
+ state.conversation.append(ServeMessage(role=role, parts=[part]))
+ else:
+ state.conversation[-1].parts.append(part)
+
+ if state.added_systext is False and sys_text_input:
+ state.added_systext = True
+ append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
+ if text_input:
+ append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
+ audio_data = None
+
+ result_audio = b""
+ async for event in agent.stream(
+ sys_audio_data,
+ audio_data,
+ sr,
+ 1,
+ chat_ctx={
+ "messages": state.conversation,
+ "added_sysaudio": state.added_sysaudio,
+ },
+ ):
+ if event.type == FishE2EEventType.USER_CODES:
+ append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
+ elif event.type == FishE2EEventType.SPEECH_SEGMENT:
+ append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
+ yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
+ elif event.type == FishE2EEventType.TEXT_SEGMENT:
+ append_to_chat_ctx(ServeTextPart(text=event.text))
+ yield state.get_history(), None, None, None
+
+ yield state.get_history(), None, None, None
+
+
+async def process_text_input(
+ sys_audio_input, sys_text_input, state: ChatState, text_input: str
+):
+ async for event in process_audio_input(
+ sys_audio_input, sys_text_input, None, state, text_input
+ ):
+ yield event
+
+
+def create_demo():
+ with gr.Blocks() as demo:
+ state = gr.State(ChatState())
+
+ with gr.Row():
+ # Left column (70%) for chatbot and notes
+ with gr.Column(scale=7):
+ chatbot = gr.Chatbot(
+ [],
+ elem_id="chatbot",
+ bubble_full_width=False,
+ height=600,
+ type="messages",
+ )
+
+ # notes = gr.Markdown(
+ # """
+ # # Fish Agent
+ # 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
+ # 2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
+ # 3. Demo为早期灰度测试版本,推理速度尚待优化.
+ # # 特色
+ # 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
+ # 2. 模型可以使用reference audio控制说话音色.
+ # 3. 可以生成具有较强情感与韵律的音频.
+ # """
+ # )
+ notes = gr.Markdown(
+ """
+ # Fish Agent
+ 1. This demo is Fish Audio's self-researh end-to-end language model, Fish Agent version 3B.
+ 2. You can find the code and weights in our official repo in [gitub](https://github.com/fishaudio/fish-speech) and [hugging face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but the content is released under a CC BY-NC-SA 4.0 licence.
+ 3. The demo is an early alpha test version, the inference speed needs to be optimised.
+ # Features
+ 1. The model automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
+ 2. The model can use reference audio to control the speech timbre.
+ 3. The model can generate speech with strong emotion.
+ """
+ )
+
+ # Right column (30%) for controls
+ with gr.Column(scale=3):
+ sys_audio_input = gr.Audio(
+ sources=["upload"],
+ type="numpy",
+ label="Give a timbre for your assistant",
+ )
+ sys_text_input = gr.Textbox(
+ label="What is your assistant's role?",
+ value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
+ type="text",
+ )
+ audio_input = gr.Audio(
+ sources=["microphone"], type="numpy", label="Speak your message"
+ )
+
+ text_input = gr.Textbox(label="Or type your message", type="text")
+
+ output_audio = gr.Audio(
+ label="Assistant's Voice",
+ streaming=True,
+ autoplay=True,
+ interactive=False,
+ )
+
+ send_button = gr.Button("Send", variant="primary")
+ clear_button = gr.Button("Clear")
+
+ # Event handlers
+ audio_input.stop_recording(
+ process_audio_input,
+ inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
+ outputs=[chatbot, output_audio, audio_input, text_input],
+ show_progress=True,
+ )
+
+ send_button.click(
+ process_text_input,
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
+ outputs=[chatbot, output_audio, audio_input, text_input],
+ show_progress=True,
+ )
+
+ text_input.submit(
+ process_text_input,
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
+ outputs=[chatbot, output_audio, audio_input, text_input],
+ show_progress=True,
+ )
+
+ clear_button.click(
+ clear_fn,
+ inputs=[],
+ outputs=[chatbot, state, audio_input, output_audio, text_input],
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ demo = create_demo()
+ demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
diff --git a/tools/extract_model.py b/tools/extract_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..97fe62507b7282890319d8dc1eaa3cbca0e1f60a
--- /dev/null
+++ b/tools/extract_model.py
@@ -0,0 +1,21 @@
+import click
+import torch
+from loguru import logger
+
+
+@click.command()
+@click.argument("model_path")
+@click.argument("output_path")
+def main(model_path, output_path):
+ if model_path == output_path:
+ logger.error("Model path and output path are the same")
+ return
+
+ logger.info(f"Loading model from {model_path}")
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ torch.save(state_dict, output_path)
+ logger.info(f"Model saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/file.py b/tools/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7a0597365252e7aecf887897ff391a061275c3f
--- /dev/null
+++ b/tools/file.py
@@ -0,0 +1,125 @@
+import base64
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+ ".mp3",
+ ".wav",
+ ".flac",
+ ".ogg",
+ ".m4a",
+ ".wma",
+ ".aac",
+ ".aiff",
+ ".aif",
+ ".aifc",
+}
+
+VIDEO_EXTENSIONS = {
+ ".mp4",
+ ".avi",
+}
+
+
+def audio_to_bytes(file_path):
+ if not file_path or not Path(file_path).exists():
+ return None
+ with open(file_path, "rb") as wav_file:
+ wav = wav_file.read()
+ return wav
+
+
+def read_ref_text(ref_text):
+ path = Path(ref_text)
+ if path.exists() and path.is_file():
+ with path.open("r", encoding="utf-8") as file:
+ return file.read()
+ return ref_text
+
+
+def list_files(
+ path: Union[Path, str],
+ extensions: set[str] = None,
+ recursive: bool = False,
+ sort: bool = True,
+) -> list[Path]:
+ """List files in a directory.
+
+ Args:
+ path (Path): Path to the directory.
+ extensions (set, optional): Extensions to filter. Defaults to None.
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
+ sort (bool, optional): Whether to sort the files. Defaults to True.
+
+ Returns:
+ list: List of files.
+ """
+
+ if isinstance(path, str):
+ path = Path(path)
+
+ if not path.exists():
+ raise FileNotFoundError(f"Directory {path} does not exist.")
+
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
+
+ if sort:
+ files = natsorted(files)
+
+ return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+ """
+ Load a Bert-VITS2 style filelist.
+ """
+
+ files = set()
+ results = []
+ count_duplicated, count_not_found = 0, 0
+
+ LANGUAGE_TO_LANGUAGES = {
+ "zh": ["zh", "en"],
+ "jp": ["jp", "en"],
+ "en": ["en"],
+ }
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ splits = line.strip().split("|", maxsplit=3)
+ if len(splits) != 4:
+ logger.warning(f"Invalid line: {line}")
+ continue
+
+ filename, speaker, language, text = splits
+ file = Path(filename)
+ language = language.strip().lower()
+
+ if language == "ja":
+ language = "jp"
+
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+ languages = LANGUAGE_TO_LANGUAGES[language]
+
+ if file in files:
+ logger.warning(f"Duplicated file: {file}")
+ count_duplicated += 1
+ continue
+
+ if not file.exists():
+ logger.warning(f"File not found: {file}")
+ count_not_found += 1
+ continue
+
+ results.append((file, speaker, languages, text))
+
+ if count_duplicated > 0:
+ logger.warning(f"Total duplicated files: {count_duplicated}")
+
+ if count_not_found > 0:
+ logger.warning(f"Total files not found: {count_not_found}")
+
+ return results
diff --git a/tools/fish_e2e.py b/tools/fish_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a44fca3462e29ef944c1b4b3779e413a5c64689
--- /dev/null
+++ b/tools/fish_e2e.py
@@ -0,0 +1,298 @@
+import base64
+import ctypes
+import io
+import json
+import os
+import struct
+from dataclasses import dataclass
+from enum import Enum
+from typing import AsyncGenerator, Union
+
+import httpx
+import numpy as np
+import ormsgpack
+import soundfile as sf
+
+from .schema import (
+ ServeChatRequest,
+ ServeMessage,
+ ServeTextPart,
+ ServeVQGANDecodeRequest,
+ ServeVQGANEncodeRequest,
+ ServeVQPart,
+)
+
+
+class CustomAudioFrame:
+ def __init__(self, data, sample_rate, num_channels, samples_per_channel):
+ if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
+ ctypes.c_int16
+ ):
+ raise ValueError(
+ "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
+ )
+
+ self._data = bytearray(data)
+ self._sample_rate = sample_rate
+ self._num_channels = num_channels
+ self._samples_per_channel = samples_per_channel
+
+ @property
+ def data(self):
+ return memoryview(self._data).cast("h")
+
+ @property
+ def sample_rate(self):
+ return self._sample_rate
+
+ @property
+ def num_channels(self):
+ return self._num_channels
+
+ @property
+ def samples_per_channel(self):
+ return self._samples_per_channel
+
+ @property
+ def duration(self):
+ return self.samples_per_channel / self.sample_rate
+
+ def __repr__(self):
+ return (
+ f"CustomAudioFrame(sample_rate={self.sample_rate}, "
+ f"num_channels={self.num_channels}, "
+ f"samples_per_channel={self.samples_per_channel}, "
+ f"duration={self.duration:.3f})"
+ )
+
+
+class FishE2EEventType(Enum):
+ SPEECH_SEGMENT = 1
+ TEXT_SEGMENT = 2
+ END_OF_TEXT = 3
+ END_OF_SPEECH = 4
+ ASR_RESULT = 5
+ USER_CODES = 6
+
+
+@dataclass
+class FishE2EEvent:
+ type: FishE2EEventType
+ frame: np.ndarray = None
+ text: str = None
+ vq_codes: list[list[int]] = None
+
+
+client = httpx.AsyncClient(
+ timeout=None,
+ limits=httpx.Limits(
+ max_connections=None,
+ max_keepalive_connections=None,
+ keepalive_expiry=None,
+ ),
+)
+
+
+class FishE2EAgent:
+ def __init__(self):
+ self.llm_url = "http://localhost:8080/v1/chat"
+ self.vqgan_url = "http://localhost:8080"
+ self.client = httpx.AsyncClient(timeout=None)
+
+ async def get_codes(self, audio_data, sample_rate):
+ audio_buffer = io.BytesIO()
+ sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
+ audio_buffer.seek(0)
+ # Step 1: Encode audio using VQGAN
+ encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
+ encode_request_bytes = ormsgpack.packb(
+ encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+ )
+ encode_response = await self.client.post(
+ f"{self.vqgan_url}/v1/vqgan/encode",
+ data=encode_request_bytes,
+ headers={"Content-Type": "application/msgpack"},
+ )
+ encode_response_data = ormsgpack.unpackb(encode_response.content)
+ codes = encode_response_data["tokens"][0]
+ return codes
+
+ async def stream(
+ self,
+ system_audio_data: np.ndarray | None,
+ user_audio_data: np.ndarray | None,
+ sample_rate: int,
+ num_channels: int,
+ chat_ctx: dict | None = None,
+ ) -> AsyncGenerator[bytes, None]:
+
+ if system_audio_data is not None:
+ sys_codes = await self.get_codes(system_audio_data, sample_rate)
+ else:
+ sys_codes = None
+ if user_audio_data is not None:
+ user_codes = await self.get_codes(user_audio_data, sample_rate)
+ # Step 2: Prepare LLM request
+ if chat_ctx is None:
+ sys_parts = [
+ ServeTextPart(
+ text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
+ ),
+ ]
+ if system_audio_data is not None:
+ sys_parts.append(ServeVQPart(codes=sys_codes))
+ chat_ctx = {
+ "messages": [
+ ServeMessage(
+ role="system",
+ parts=sys_parts,
+ ),
+ ],
+ }
+ else:
+ if chat_ctx["added_sysaudio"] is False and sys_codes:
+ chat_ctx["added_sysaudio"] = True
+ chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
+
+ prev_messages = chat_ctx["messages"].copy()
+ if user_audio_data is not None:
+ yield FishE2EEvent(
+ type=FishE2EEventType.USER_CODES,
+ vq_codes=user_codes,
+ )
+ else:
+ user_codes = None
+
+ request = ServeChatRequest(
+ messages=prev_messages
+ + (
+ [
+ ServeMessage(
+ role="user",
+ parts=[ServeVQPart(codes=user_codes)],
+ )
+ ]
+ if user_codes
+ else []
+ ),
+ streaming=True,
+ num_samples=1,
+ )
+
+ # Step 3: Stream LLM response and decode audio
+ buffer = b""
+ vq_codes = []
+ current_vq = False
+
+ async def decode_send():
+ nonlocal current_vq
+ nonlocal vq_codes
+
+ data = np.concatenate(vq_codes, axis=1).tolist()
+ # Decode VQ codes to audio
+ decode_request = ServeVQGANDecodeRequest(tokens=[data])
+ decode_response = await self.client.post(
+ f"{self.vqgan_url}/v1/vqgan/decode",
+ data=ormsgpack.packb(
+ decode_request,
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ ),
+ headers={"Content-Type": "application/msgpack"},
+ )
+ decode_data = ormsgpack.unpackb(decode_response.content)
+
+ # Convert float16 audio data to int16
+ audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
+ audio_data = (audio_data * 32768).astype(np.int16).tobytes()
+
+ audio_frame = CustomAudioFrame(
+ data=audio_data,
+ samples_per_channel=len(audio_data) // 2,
+ sample_rate=44100,
+ num_channels=1,
+ )
+ yield FishE2EEvent(
+ type=FishE2EEventType.SPEECH_SEGMENT,
+ frame=audio_frame,
+ vq_codes=data,
+ )
+
+ current_vq = False
+ vq_codes = []
+
+ async with self.client.stream(
+ "POST",
+ self.llm_url,
+ data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+ headers={"Content-Type": "application/msgpack"},
+ ) as response:
+
+ async for chunk in response.aiter_bytes():
+ buffer += chunk
+
+ while len(buffer) >= 4:
+ read_length = struct.unpack("I", buffer[:4])[0]
+ if len(buffer) < 4 + read_length:
+ break
+
+ body = buffer[4 : 4 + read_length]
+ buffer = buffer[4 + read_length :]
+ data = ormsgpack.unpackb(body)
+
+ if data["delta"] and data["delta"]["part"]:
+ if current_vq and data["delta"]["part"]["type"] == "text":
+ async for event in decode_send():
+ yield event
+ if data["delta"]["part"]["type"] == "text":
+ yield FishE2EEvent(
+ type=FishE2EEventType.TEXT_SEGMENT,
+ text=data["delta"]["part"]["text"],
+ )
+ elif data["delta"]["part"]["type"] == "vq":
+ vq_codes.append(np.array(data["delta"]["part"]["codes"]))
+ current_vq = True
+
+ if current_vq and vq_codes:
+ async for event in decode_send():
+ yield event
+
+ yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
+ yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
+
+
+# Example usage:
+async def main():
+ import torchaudio
+
+ agent = FishE2EAgent()
+
+ # Replace this with actual audio data loading
+ with open("uz_story_en.m4a", "rb") as f:
+ audio_data = f.read()
+
+ audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
+ audio_data = (audio_data.numpy() * 32768).astype(np.int16)
+
+ stream = agent.stream(audio_data, sample_rate, 1)
+ if os.path.exists("audio_segment.wav"):
+ os.remove("audio_segment.wav")
+
+ async for event in stream:
+ if event.type == FishE2EEventType.SPEECH_SEGMENT:
+ # Handle speech segment (e.g., play audio or save to file)
+ with open("audio_segment.wav", "ab+") as f:
+ f.write(event.frame.data)
+ elif event.type == FishE2EEventType.ASR_RESULT:
+ print(event.text, flush=True)
+ elif event.type == FishE2EEventType.TEXT_SEGMENT:
+ print(event.text, flush=True, end="")
+ elif event.type == FishE2EEventType.END_OF_TEXT:
+ print("\nEnd of text reached.")
+ elif event.type == FishE2EEventType.END_OF_SPEECH:
+ print("End of speech reached.")
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ asyncio.run(main())
diff --git a/tools/inference_engine/__init__.py b/tools/inference_engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c3e476c0383e48f7d400f08f6c7db61084e3655
--- /dev/null
+++ b/tools/inference_engine/__init__.py
@@ -0,0 +1,192 @@
+import gc
+import queue
+from typing import Generator
+
+import numpy as np
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from tools.inference_engine.reference_loader import ReferenceLoader
+from tools.inference_engine.utils import InferenceResult, wav_chunk_header
+from tools.inference_engine.vq_manager import VQManager
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+)
+from tools.schema import ServeTTSRequest
+
+
+class TTSInferenceEngine(ReferenceLoader, VQManager):
+
+ def __init__(
+ self,
+ llama_queue: queue.Queue,
+ decoder_model: FireflyArchitecture,
+ precision: torch.dtype,
+ compile: bool,
+ ) -> None:
+
+ super().__init__()
+
+ self.llama_queue = llama_queue
+ self.decoder_model = decoder_model
+ self.precision = precision
+ self.compile = compile
+
+ @torch.inference_mode()
+ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
+ """
+ Main inference function:
+ - Loads the reference audio and text.
+ - Calls the LLAMA model for inference.
+ - Decodes the VQ tokens to audio.
+ """
+
+ ref_id: str | None = req.reference_id
+ prompt_tokens, prompt_texts = [], []
+ # Load the reference audio and text based on id or hash
+ if ref_id is not None:
+ prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
+
+ elif req.references:
+ prompt_tokens, prompt_texts = self.load_by_hash(
+ req.references, req.use_memory_cache
+ )
+
+ # Set the random seed if provided
+ if req.seed is not None:
+ set_seed(req.seed)
+ logger.warning(f"set seed: {req.seed}")
+
+ # Get the symbolic tokens from the LLAMA model
+ response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
+
+ # Get the sample rate from the decoder model
+ sample_rate = self.decoder_model.spec_transform.sample_rate
+
+ # If streaming, send the header
+ if req.streaming:
+ yield InferenceResult(
+ code="header",
+ audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
+ error=None,
+ )
+
+ segments = []
+
+ while True:
+ # Get the response from the LLAMA model
+ wrapped_result: WrappedGenerateResponse = response_queue.get()
+ if wrapped_result.status == "error":
+ yield InferenceResult(
+ code="error",
+ audio=None,
+ error=(
+ wrapped_result.response
+ if isinstance(wrapped_result.response, Exception)
+ else Exception("Unknown error")
+ ),
+ )
+ break
+
+ # Check the response type
+ if not isinstance(wrapped_result.response, GenerateResponse):
+ raise TypeError(
+ "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
+ )
+
+ result: GenerateResponse = wrapped_result.response
+ if result.action != "next":
+ segment = self.get_audio_segment(result)
+
+ if req.streaming: # Used only by the API server
+ yield InferenceResult(
+ code="segment",
+ audio=(sample_rate, segment),
+ error=None,
+ )
+ segments.append(segment)
+ else:
+ break
+
+ # Clean up the memory
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Edge case: no audio generated
+ if len(segments) == 0:
+ yield InferenceResult(
+ code="error",
+ audio=None,
+ error=RuntimeError("No audio generated, please check the input text."),
+ )
+ else:
+ # Streaming or not, return the final audio
+ audio = np.concatenate(segments, axis=0)
+ yield InferenceResult(
+ code="final",
+ audio=(sample_rate, audio),
+ error=None,
+ )
+
+ return None
+
+ def send_Llama_request(
+ self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
+ ) -> queue.Queue:
+ """
+ Send a request to the LLAMA model to generate the symbolic tokens.
+ """
+
+ # Prepare the request
+ request = dict(
+ device=self.decoder_model.device,
+ max_new_tokens=req.max_new_tokens,
+ text=(
+ req.text
+ if not req.normalize
+ else ChnNormedText(raw_text=req.text).normalize()
+ ),
+ top_p=req.top_p,
+ repetition_penalty=req.repetition_penalty,
+ temperature=req.temperature,
+ compile=self.compile,
+ iterative_prompt=req.chunk_length > 0,
+ chunk_length=req.chunk_length,
+ max_length=4096,
+ prompt_tokens=prompt_tokens,
+ prompt_text=prompt_texts,
+ )
+
+ # Create a queue to get the response
+ response_queue = queue.Queue()
+
+ # Send the request to the LLAMA model
+ self.llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ return response_queue
+
+ def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
+ """
+ Decode the VQ tokens to audio.
+ """
+
+ # Don't use autocast on MPS devices
+ with autocast_exclude_mps(
+ device_type=self.decoder_model.device.type, dtype=self.precision
+ ):
+ # Decode the symbolic tokens to audio
+ segment = self.decode_vq_tokens(codes=result.codes)
+
+ # Convert the audio to numpy
+ return segment.float().cpu().numpy()
diff --git a/tools/inference_engine/__pycache__/__init__.cpython-310.pyc b/tools/inference_engine/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48fd5bf525a88cc02a2d4f4b65990110d7bc2f50
Binary files /dev/null and b/tools/inference_engine/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tools/inference_engine/__pycache__/reference_loader.cpython-310.pyc b/tools/inference_engine/__pycache__/reference_loader.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2becba5657b05a138f8a454edc515bedddec3c8
Binary files /dev/null and b/tools/inference_engine/__pycache__/reference_loader.cpython-310.pyc differ
diff --git a/tools/inference_engine/__pycache__/utils.cpython-310.pyc b/tools/inference_engine/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f14100df47f5f1e7674e030afba84d0d23112279
Binary files /dev/null and b/tools/inference_engine/__pycache__/utils.cpython-310.pyc differ
diff --git a/tools/inference_engine/__pycache__/vq_manager.cpython-310.pyc b/tools/inference_engine/__pycache__/vq_manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cea7a0fd61c55d11d5a37404cd3cc24fa05d33a
Binary files /dev/null and b/tools/inference_engine/__pycache__/vq_manager.cpython-310.pyc differ
diff --git a/tools/inference_engine/reference_loader.py b/tools/inference_engine/reference_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f19c6dd7ec34bd6ed6ce9110c772e24c5a2b282
--- /dev/null
+++ b/tools/inference_engine/reference_loader.py
@@ -0,0 +1,125 @@
+import io
+from hashlib import sha256
+from pathlib import Path
+from typing import Callable, Literal, Tuple
+
+import torch
+import torchaudio
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
+from tools.schema import ServeReferenceAudio
+
+
+class ReferenceLoader:
+
+ def __init__(self) -> None:
+ """
+ Component of the TTSInferenceEngine class.
+ Loads and manages the cache for the reference audio and text.
+ """
+ self.ref_by_id: dict = {}
+ self.ref_by_hash: dict = {}
+
+ # Make Pylance happy (attribut/method not defined...)
+ self.decoder_model: FireflyArchitecture
+ self.encode_reference: Callable
+
+ # Define the torchaudio backend
+ backends = torchaudio.list_audio_backends()
+ if "ffmpeg" in backends:
+ self.backend = "ffmpeg"
+ else:
+ self.backend = "soundfile"
+
+ def load_by_id(
+ self,
+ id: str,
+ use_cache: Literal["on", "off"],
+ ) -> Tuple:
+
+ # Load the references audio and text by id
+ ref_folder = Path("references") / id
+ ref_folder.mkdir(parents=True, exist_ok=True)
+ ref_audios = list_files(
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+ )
+
+ if use_cache == "off" or id not in self.ref_by_id:
+ # If the references are not already loaded, encode them
+ prompt_tokens = [
+ self.encode_reference(
+ # decoder_model=self.decoder_model,
+ reference_audio=audio_to_bytes(str(ref_audio)),
+ enable_reference_audio=True,
+ )
+ for ref_audio in ref_audios
+ ]
+ prompt_texts = [
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
+ for ref_audio in ref_audios
+ ]
+ self.ref_by_id[id] = (prompt_tokens, prompt_texts)
+
+ else:
+ # Reuse already encoded references
+ logger.info("Use same references")
+ prompt_tokens, prompt_texts = self.ref_by_id[id]
+
+ return prompt_tokens, prompt_texts
+
+ def load_by_hash(
+ self,
+ references: list[ServeReferenceAudio],
+ use_cache: Literal["on", "off"],
+ ) -> Tuple:
+
+ # Load the references audio and text by hash
+ audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
+
+ cache_used = False
+ prompt_tokens, prompt_texts = [], []
+ for i, ref in enumerate(references):
+ if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
+ # If the references are not already loaded, encode them
+ prompt_tokens.append(
+ self.encode_reference(
+ reference_audio=ref.audio,
+ enable_reference_audio=True,
+ )
+ )
+ prompt_texts.append(ref.text)
+ self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
+
+ else:
+ # Reuse already encoded references
+ prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
+ cache_used = True
+
+ if cache_used:
+ logger.info("Use same references")
+
+ return prompt_tokens, prompt_texts
+
+ def load_audio(self, reference_audio, sr):
+ """
+ Load the audio data from a file or bytes.
+ """
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
+ audio_data = reference_audio
+ reference_audio = io.BytesIO(audio_data)
+
+ waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
+
+ if waveform.shape[0] > 1:
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+ if original_sr != sr:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=original_sr, new_freq=sr
+ )
+ waveform = resampler(waveform)
+
+ audio = waveform.squeeze().numpy()
+ return audio
diff --git a/tools/inference_engine/utils.py b/tools/inference_engine/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a6a5ae3545af8f157db9ec1f84a065b9bf058f5
--- /dev/null
+++ b/tools/inference_engine/utils.py
@@ -0,0 +1,39 @@
+import io
+import wave
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+
+
+@dataclass
+class InferenceResult:
+ code: Literal["header", "segment", "error", "final"]
+ audio: Optional[Tuple[int, np.ndarray | bytes]]
+ error: Optional[Exception]
+
+
+def normalize_text(user_input: str, use_normalization: bool) -> str:
+ """Normalize user input text if needed."""
+ if use_normalization:
+ return ChnNormedText(raw_text=user_input).normalize()
+ else:
+ return user_input
+
+
+def wav_chunk_header(
+ sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
+) -> bytes:
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+
+ return wav_header_bytes
diff --git a/tools/inference_engine/vq_manager.py b/tools/inference_engine/vq_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..07b5cb6d8cf85fa81813173606492441eb153463
--- /dev/null
+++ b/tools/inference_engine/vq_manager.py
@@ -0,0 +1,57 @@
+from typing import Callable
+
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+
+
+class VQManager:
+
+ def __init__(self):
+ # Make Pylance happy (attribut/method not defined...)
+ self.decoder_model: FireflyArchitecture
+ self.load_audio: Callable
+
+ def decode_vq_tokens(self, codes):
+ feature_lengths = torch.tensor(
+ [codes.shape[1]], device=self.decoder_model.device
+ )
+ logger.info(f"VQ features: {codes.shape}")
+
+ if isinstance(self.decoder_model, FireflyArchitecture):
+ return self.decoder_model.decode(
+ indices=codes[None],
+ feature_lengths=feature_lengths,
+ )[0].squeeze()
+
+ raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+
+ def encode_reference(self, reference_audio, enable_reference_audio):
+ if enable_reference_audio and reference_audio is not None:
+ # Load audios, and prepare basic info here
+ reference_audio_content = self.load_audio(
+ reference_audio, self.decoder_model.spec_transform.sample_rate
+ )
+
+ audios = torch.from_numpy(reference_audio_content).to(
+ self.decoder_model.device
+ )[None, None, :]
+ audio_lengths = torch.tensor(
+ [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+ )
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ if isinstance(self.decoder_model, FireflyArchitecture):
+ prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+ else:
+ raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+ else:
+ prompt_tokens = None
+ logger.info("No reference audio provided")
+
+ return prompt_tokens
diff --git a/tools/llama/__pycache__/generate.cpython-310.pyc b/tools/llama/__pycache__/generate.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e655ea2b0fa084f8671a9b7d9ecfb8c5a7082d00
Binary files /dev/null and b/tools/llama/__pycache__/generate.cpython-310.pyc differ
diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc5ef120cce2e04b24f0f897e49f022cb1946c97
--- /dev/null
+++ b/tools/llama/build_dataset.py
@@ -0,0 +1,169 @@
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from tools.file import load_filelist
+
+# To avoid CPU overload
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+ files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+ files = sorted(files)
+
+ grouped_files = defaultdict(list)
+ for file in tqdm(files, desc=f"Grouping {root}"):
+ p = str(file.parent)
+ speaker = file.parent.name
+
+ try:
+ if isinstance(text_extension, str):
+ texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
+ else:
+ texts = [
+ file.with_suffix(ext).read_text(encoding="utf-8")
+ for ext in text_extension
+ ]
+ except Exception as e:
+ logger.error(f"Failed to read text {file}: {e}")
+ continue
+
+ grouped_files[p].append((speaker, file, texts))
+
+ logger.info(
+ f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+ )
+
+ for i in grouped_files.values():
+ subset = [(f, t) for _, f, t in i]
+ yield i[0][0], subset, "folder"
+
+
+def task_generator_filelist(filelist):
+ grouped_files = defaultdict(list)
+ for filename, speaker, _, text in load_filelist(filelist):
+ grouped_files[speaker].append((Path(filename), [text]))
+
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+ for speaker, values in grouped_files.items():
+ yield speaker, values, "filelist"
+
+
+def run_task(task):
+ name, subset, source = task
+
+ # Parse the files
+ sentences = []
+ for file, texts in subset:
+ np_file = file.with_suffix(".npy")
+ if np_file.exists() is False:
+ logger.warning(f"Can't find {np_file}")
+ continue
+
+ new_texts = []
+
+ for text in texts:
+ # Simple cleaning: replace { xxx } and < xxx > with space
+ text = re.sub(r"\{.*?\}", " ", text)
+ text = re.sub(r"<.*?>", " ", text)
+ text = re.sub(r"\s+", " ", text)
+ new_texts.append(text)
+
+ try:
+ semantics = np.load(np_file)
+ except Exception as e:
+ logger.error(f"Failed to parse {file}: {e}")
+ continue
+
+ if isinstance(semantics, np.ndarray):
+ semantics = semantics.tolist()
+
+ sentences.append(
+ Sentence(
+ texts=new_texts,
+ semantics=[Semantics(values=s) for s in semantics],
+ )
+ )
+
+ # Pack the sentences
+ return pack_pb_stream(
+ TextData(
+ source=source,
+ name=name,
+ sentences=sentences,
+ )
+ )
+
+
+@click.command()
+@click.option(
+ "--input",
+ type=click.Path(path_type=Path),
+ required=True,
+ help="A folder containing the dataset or a filelist",
+ multiple=True,
+)
+@click.option(
+ "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+ "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+ generator_fns = []
+
+ for f in input:
+ assert f.exists(), f"{f} not found"
+
+ if f.is_dir():
+ generator_fn = task_generator_folder(f, text_extension)
+ else:
+ generator_fn = task_generator_filelist(f)
+
+ generator_fns.append(generator_fn)
+
+ generator_fn = itertools.chain(*generator_fns)
+ output.mkdir(parents=True, exist_ok=True)
+
+ dataset_fp = None
+ tar_idx = 0
+ written_size = 0
+
+ with Pool(num_workers) as p:
+ for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+ if dataset_fp is None:
+ dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+ dataset_fp.write(result)
+ written_size += len(result)
+
+ if written_size > shard_size * 1024 * 1024:
+ logger.info(f"Finished writing {tar_idx} shards to {output}")
+ dataset_fp.close()
+ dataset_fp = None
+ written_size = 0
+ tar_idx += 1
+
+ if dataset_fp is not None:
+ dataset_fp.close()
+
+ logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d70940487388185381246d8210a49a58e55743
--- /dev/null
+++ b/tools/llama/eval_in_context.py
@@ -0,0 +1,171 @@
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
+from tools.llama.generate import load_model
+
+
+def smooth(
+ scalars: list[float], weight: float
+) -> list[float]: # Weight between 0 and 1
+ last = scalars[0] # First value in the plot (first timestep)
+ smoothed = list()
+ for point in scalars:
+ smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
+ smoothed.append(smoothed_val) # Save it
+ last = smoothed_val # Anchor the last smoothed value
+
+ return smoothed
+
+
+@torch.inference_mode()
+def analyze_one_model(loader, config, weight, max_length):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = load_model(
+ config,
+ weight,
+ device,
+ torch.bfloat16,
+ max_length,
+ compile=False,
+ )[0]
+
+ current_step = 0
+ model.eval()
+
+ semantic_loss_sum = torch.zeros(
+ max_length,
+ dtype=torch.float32,
+ device=device,
+ )
+ counter = torch.zeros(
+ max_length,
+ dtype=torch.long,
+ device=device,
+ )
+
+ for batch in loader:
+ batch = {k: v.to(device) for k, v in batch.items()}
+
+ labels = batch["labels"]
+ outputs = model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+ # Generate labels
+ base_loss = F.cross_entropy(
+ token_logits.reshape(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ reduction="none",
+ )
+
+ codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ reduction="none",
+ )
+
+ base_loss = base_loss.reshape(labels[:, 0].shape)
+ semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+ semantic_loss_frame = semantic_loss.mean(-1)
+ pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+ for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+ semantic_loss_sum[~pad] += loss_sample[~pad]
+ counter[~pad] += 1
+
+ current_step += 1
+ if current_step == 10:
+ break
+
+ semantic_loss = semantic_loss.cpu()
+ counter = counter.cpu()
+ xs, ys = [], []
+
+ for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+ if count > 0:
+ xs.append(i)
+ ys.append((loss / count).item()) # for better loss visualization
+
+ smoothed_ys = smooth(ys, 0.95)
+
+ # Unload model
+ del model
+ torch.cuda.empty_cache()
+
+ return xs, ys, smoothed_ys
+
+
+def main():
+ tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+ max_length = 4096
+
+ ds = AutoAugTextDataset(
+ ["data/protos/sft/云天河"],
+ tokenizer=tokenizer,
+ use_speaker=False,
+ interactive_prob=1.0,
+ max_length=max_length,
+ )
+
+ loader = DataLoader(
+ ds,
+ batch_size=8,
+ collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+ num_workers=0,
+ shuffle=False,
+ )
+
+ plt.figure(figsize=(10, 5), dpi=200)
+
+ plt.xlabel("Frame")
+ plt.ylabel("Loss")
+ plt.yscale("log")
+ plt.title("Semantic Loss")
+ plt.grid(which="both", axis="both")
+ plt.xlim(0, max_length)
+
+ tests = [
+ (
+ "pertrain-medium",
+ "dual_ar_2_codebook_medium",
+ "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+ ),
+ (
+ "sft-medium",
+ "dual_ar_2_codebook_medium",
+ "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+ ),
+ (
+ "sft-large",
+ "dual_ar_2_codebook_large",
+ "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+ ),
+ ]
+
+ for name, config, weight in tests:
+ xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+ plt.plot(xs, smoothed_ys, label=name)
+
+ plt.legend()
+ plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/generate.py b/tools/llama/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..216890dcd1c1fab8b4d37894369fcc776fc7692f
--- /dev/null
+++ b/tools/llama/generate.py
@@ -0,0 +1,1115 @@
+import os
+import queue
+import threading
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Union
+
+import click
+import hydra
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from loguru import logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import (
+ CODEBOOK_PAD_TOKEN_ID,
+ Conversation,
+ Message,
+ TextPart,
+ VQPart,
+)
+from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+ # Experimental feature to reduce compilation times, will be on by default in future
+ torch._inductor.config.fx_graph_cache = True
+
+
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from fish_speech.models.text2semantic.llama import (
+ BaseTransformer,
+ DualARTransformer,
+ NaiveTransformer,
+)
+
+
+def multinomial_sample_one_no_sync(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ temperature: torch.Tensor = 1.0,
+ top_p: torch.Tensor = 1.0,
+ repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+ # Apply repetition penalty
+ if previous_tokens is not None:
+ previous_tokens = previous_tokens.long()
+ score = torch.gather(logits, dim=0, index=previous_tokens)
+ score = torch.where(
+ score < 0, score * repetition_penalty, score / repetition_penalty
+ )
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
+
+ # Apply top-p sampling
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
+ )
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+ logits = logits / max(temperature, 1e-5)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
+
+
+def multinomial_sample_one_no_sync_agent(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs_agent(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ temperature: torch.Tensor = 1.0,
+ top_p: torch.Tensor = 1.0,
+ repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+ # Apply repetition penalty
+ if previous_tokens is not None:
+ previous_tokens = previous_tokens.long()
+ score = torch.gather(logits, dim=-1, index=previous_tokens)
+ score = torch.where(
+ score < 0, score * repetition_penalty, score / repetition_penalty
+ )
+ logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
+ # Apply top-p sampling
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[..., 0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+ )
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+ logits = logits / max(temperature, 1e-5)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
+
+
+def sample(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ probs = logits_to_probs(
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+ )
+ idx_next = multinomial_sample_one_no_sync(probs)
+ return idx_next, probs
+
+
+def sample_agent(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ probs = logits_to_probs_agent(
+ logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
+ )
+ idx_next = multinomial_sample_one_no_sync_agent(probs)
+ return idx_next, probs
+
+
+def decode_one_token_ar_agent(
+ model: DualARTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ semantic_ids: list,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ # print(x, input_pos)
+ x = model.forward_generate(x, input_pos)
+ logits = x.logits # [:, -1:]
+ hidden_states = x.hidden_states # [:, -1:]
+
+ sampling_kwargs_main = sampling_kwargs.copy()
+ sampling_kwargs_main["temperature"] = 0.1
+ sampling_kwargs_main["top_p"] = 0.1
+ sampling_kwargs_main["repetition_penalty"] = 1.0
+
+ codebooks = [
+ sample_agent(
+ logits,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs_main,
+ )[0]
+ ]
+
+ # Cleanup the cache
+ for layer in model.fast_layers:
+ layer.attention.kv_cache.k_cache.fill_(0)
+ layer.attention.kv_cache.v_cache.fill_(0)
+
+ for codebook_idx in range(model.config.num_codebooks):
+ input_pos = torch.tensor(
+ [codebook_idx], device=hidden_states.device, dtype=torch.long
+ )
+ logits = model.forward_generate_fast(hidden_states, input_pos)
+ a = sample_agent(
+ logits,
+ previous_tokens=(
+ previous_tokens[:, codebook_idx + 1]
+ if previous_tokens is not None
+ else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ hidden_states = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ codebooks = torch.stack(codebooks, dim=1)
+ semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+ codebooks[:, 1:, :] = torch.masked_fill(
+ codebooks[:, 1:, :],
+ ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+ CODEBOOK_PAD_TOKEN_ID,
+ )
+
+ return codebooks
+
+
+def decode_one_token_naive_agent(
+ model: NaiveTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ semantic_ids: list,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+
+ codebooks = [
+ sample(
+ x.token_logits,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs,
+ )[0]
+ ]
+
+ for i in range(model.config.num_codebooks):
+ codebooks.append(
+ sample_agent(
+ x.codebook_logits[:, :, i],
+ previous_tokens=(
+ previous_tokens[:, i + 1] if previous_tokens is not None else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ )
+
+ codebooks = torch.stack(codebooks, dim=1)
+ semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+ codebooks[:, 1:, :] = torch.masked_fill(
+ codebooks[:, 1:, :],
+ ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+ CODEBOOK_PAD_TOKEN_ID,
+ )
+
+ return codebooks
+
+
+def decode_one_token_ar(
+ model: DualARTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ semantic_ids: list,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+
+ sampling_kwargs_main = sampling_kwargs.copy()
+ # sampling_kwargs_main["temperature"] = 0.1
+ # sampling_kwargs_main["top_p"] = 0.1
+ # sampling_kwargs_main["repetition_penalty"] = 1.0
+
+ codebooks = [
+ sample(
+ x.logits,
+ previous_tokens=(
+ previous_tokens[0] if previous_tokens is not None else None
+ ), # Disable repetition penalty for the token codebook
+ **sampling_kwargs_main,
+ )[0]
+ ]
+
+ hidden_states = x.hidden_states
+
+ # Cleanup the cache
+ for layer in model.fast_layers:
+ layer.attention.kv_cache.k_cache.fill_(0)
+ layer.attention.kv_cache.v_cache.fill_(0)
+
+ input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+ model.forward_generate_fast(hidden_states, input_pos)
+ a = codebooks[0] - model.tokenizer.semantic_begin_id
+ a[a < 0] = 0
+ hidden_states = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ for codebook_idx in range(1, model.config.num_codebooks):
+ input_pos = torch.tensor(
+ [codebook_idx], device=hidden_states.device, dtype=torch.long
+ )
+ logits = model.forward_generate_fast(hidden_states, input_pos)
+ a = sample(
+ logits,
+ previous_tokens=(
+ previous_tokens[codebook_idx + 1]
+ if previous_tokens is not None
+ else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ hidden_states = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ codebooks = torch.stack(codebooks, dim=0)
+ # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+ # codebooks[1:, :] = torch.masked_fill(
+ # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+ # )
+
+ # print(codebooks)
+ return codebooks
+
+
+def decode_one_token_naive(
+ model: NaiveTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+
+ sampling_kwargs_main = sampling_kwargs.copy()
+ sampling_kwargs_main["temperature"] = 0.1
+ sampling_kwargs_main["top_p"] = 0.1
+ sampling_kwargs_main["repetition_penalty"] = 1.0
+
+ codebooks = [
+ sample(
+ x.logits,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs_main,
+ )[0]
+ ]
+
+ for i in range(model.config.num_codebooks):
+ codebooks.append(
+ sample(
+ x.codebook_logits[:, :, i],
+ previous_tokens=(
+ previous_tokens[i + 1] if previous_tokens is not None else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ )
+
+ return torch.stack(codebooks, dim=0)
+
+
+def decode_n_tokens(
+ model: NaiveTransformer,
+ cur_token: torch.Tensor,
+ input_pos: torch.Tensor,
+ num_new_tokens: int,
+ semantic_ids: list,
+ decode_one_token=decode_one_token_naive,
+ **sampling_kwargs,
+):
+ previous_tokens = torch.zeros(
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
+ dtype=torch.int,
+ device=cur_token.device,
+ )
+
+ for i in tqdm(range(num_new_tokens)):
+ # We need to get windowed repeat penalty
+ win_size = 16
+ if i < win_size:
+ window = previous_tokens[:, :win_size]
+ else:
+ window = previous_tokens[:, i - win_size : i]
+
+ with (
+ torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ )
+ if torch.cuda.is_available()
+ else nullcontext()
+ ): # Actually better for Inductor to codegen attention here
+ next_token = decode_one_token(
+ model=model,
+ x=cur_token,
+ input_pos=input_pos,
+ previous_tokens=window,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+
+ input_pos += 1
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+ previous_tokens[:, i : i + 1] = next_token.view(
+ model.config.num_codebooks + 1, -1
+ )
+
+ if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
+ break
+
+ return previous_tokens[:, : i + 1]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate(
+ *,
+ model: NaiveTransformer,
+ prompt: torch.Tensor,
+ max_new_tokens: int,
+ decode_one_token=decode_one_token_naive,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ """
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+ """
+
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ T = prompt.size(1)
+ # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+ semantic_ids = [
+ model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+ ]
+
+ if max_new_tokens:
+ if T + max_new_tokens > model.config.max_seq_len:
+ max_new_tokens = model.config.max_seq_len - T
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+ T_new = T + max_new_tokens
+ else:
+ T_new = model.config.max_seq_len
+ max_new_tokens = T_new - T
+
+ device, dtype = prompt.device, prompt.dtype
+
+ codebook_dim = 1 + model.config.num_codebooks
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ empty = torch.empty(
+ (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
+ )
+ empty[:, :T] = prompt
+ seq = empty
+ input_pos = torch.arange(0, T, device=device)
+
+ # Use non-accelerated version for now, to avoid compilation overhead
+ prefill_decode = (
+ decode_one_token_naive
+ if isinstance(model, NaiveTransformer)
+ else decode_one_token_ar
+ )
+
+ next_token = prefill_decode(
+ model,
+ prompt.view(1, codebook_dim, -1),
+ input_pos,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+ seq[:, T : T + 1] = next_token
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+ x = decode_n_tokens(
+ model,
+ next_token.view(1, codebook_dim, -1),
+ input_pos,
+ max_new_tokens - 1,
+ decode_one_token=decode_one_token,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+ # x = torch.cat(generated_tokens, dim=1)
+ seq = seq[:, : T + 1 + x.size(1)]
+ seq[:, T + 1 :] = x
+
+ return seq
+
+
+def decode_n_tokens_agent(
+ model: NaiveTransformer,
+ cur_token: torch.Tensor,
+ input_pos: torch.Tensor,
+ num_new_tokens: int,
+ semantic_ids: list,
+ im_end_id: int = 4,
+ decode_one_token=decode_one_token_naive_agent,
+ early_stop_threshold: float = 0.6,
+ **sampling_kwargs,
+):
+ batch_size = cur_token.size(0)
+ previous_tokens = torch.zeros(
+ (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
+ dtype=torch.int,
+ device=cur_token.device,
+ )
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
+ finished = finished | (cur_token[:, 0, -1] == im_end_id)
+ start_time = time.time()
+
+ for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
+ # We need to get windowed repeat penalty
+ win_size = 16
+ if i < win_size:
+ window = previous_tokens[:, :, :win_size]
+ else:
+ window = previous_tokens[:, :, i - win_size : i]
+
+ with sdpa_kernel(
+ SDPBackend.MATH
+ ): # Actually better for Inductor to codegen attention here
+ next_token = decode_one_token(
+ model=model,
+ x=cur_token,
+ input_pos=input_pos,
+ previous_tokens=window,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+
+ input_pos += 1
+ cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
+ previous_tokens[:, :, i : i + 1] = next_token.view(
+ batch_size, model.config.num_codebooks + 1, -1
+ )
+
+ yield cur_token.cpu()
+
+ finished = finished | (cur_token[:, 0, -1] == im_end_id)
+ if finished.all() or (
+ 0 < early_stop_threshold < 1
+ and finished.sum() >= round(batch_size * early_stop_threshold)
+ ):
+ break
+
+ total_time = time.time() - start_time
+ generated_tokens = i + 1
+ tokens_per_second = (generated_tokens / total_time) * batch_size
+ logger.info(
+ f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
+ )
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate_agent(
+ *,
+ model: BaseTransformer,
+ prompt: torch.Tensor,
+ max_new_tokens: int,
+ semantic_ids: list,
+ im_end_id: int = 4,
+ decode_one_token=decode_one_token_naive_agent,
+ num_samples: int = 1,
+ early_stop_threshold: float = 0.6,
+ **sampling_kwargs,
+):
+ """
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+ """
+
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ T = prompt.size(1)
+ prompt = prompt[None].repeat(num_samples, 1, 1)
+
+ if T >= model.config.max_seq_len:
+ raise ValueError(
+ f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+ )
+
+ if max_new_tokens:
+ if T + max_new_tokens > model.config.max_seq_len:
+ max_new_tokens = model.config.max_seq_len - T
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+ T_new = T + max_new_tokens
+ else:
+ T_new = model.config.max_seq_len
+ max_new_tokens = T_new - T
+
+ device, dtype = prompt.device, prompt.dtype
+
+ codebook_dim = 1 + model.config.num_codebooks
+ input_pos = torch.arange(0, T, device=device)
+
+ # Use non-accelerated version for now, to avoid compilation overhead
+ prefill_decode = (
+ decode_one_token_naive_agent
+ if isinstance(model, NaiveTransformer)
+ else decode_one_token_ar_agent
+ )
+ next_token = prefill_decode(
+ model,
+ prompt,
+ input_pos,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ ).view(num_samples, codebook_dim, -1)
+ yield next_token.cpu()
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+ yield from decode_n_tokens_agent(
+ model,
+ next_token,
+ input_pos,
+ max_new_tokens - 1,
+ im_end_id=im_end_id,
+ semantic_ids=semantic_ids,
+ decode_one_token=decode_one_token,
+ early_stop_threshold=early_stop_threshold,
+ **sampling_kwargs,
+ )
+
+
+def encode_tokens(
+ tokenizer,
+ string,
+ device="cuda",
+ prompt_tokens=None,
+ num_codebooks=4,
+):
+ string = clean_text(string)
+
+ messages = []
+ messages.append(
+ Message(
+ role="user",
+ parts=[TextPart(text=string)],
+ cal_loss=False,
+ )
+ )
+
+ if prompt_tokens is not None:
+ if prompt_tokens.ndim == 3:
+ assert (
+ prompt_tokens.shape[0] == 1
+ ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+ prompt_tokens = prompt_tokens[0]
+
+ assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
+
+ if prompt_tokens.shape[0] > num_codebooks:
+ logger.warning(
+ f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+ )
+ prompt_tokens = prompt_tokens[:num_codebooks]
+
+ vq_part = VQPart(codes=prompt_tokens.to(device))
+
+ messages.append(
+ Message(
+ role="assistant",
+ parts=[TextPart(text="<|voice|>"), vq_part],
+ cal_loss=False,
+ )
+ )
+ else:
+ messages.append(
+ Message(
+ role="assistant",
+ parts=[TextPart(text="<|voice|>")],
+ cal_loss=False,
+ add_im_end=False,
+ )
+ )
+
+ conversation = Conversation(messages=messages)
+ # conversation.visualize(tokenizer)
+ encoded = conversation.encode_for_inference(
+ tokenizer=tokenizer,
+ num_codebooks=num_codebooks,
+ )
+
+ return encoded.to(device)
+
+
+def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
+ model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
+ checkpoint_path, load_weights=True, is_agent=is_agent
+ )
+
+ model = model.to(device=device, dtype=precision)
+ logger.info(f"Restored model from checkpoint")
+
+ if isinstance(model, DualARTransformer):
+ decode_one_token = (
+ decode_one_token_ar_agent if is_agent else decode_one_token_ar
+ )
+ logger.info("Using DualARTransformer")
+ else:
+ decode_one_token = (
+ decode_one_token_naive_agent if is_agent else decode_one_token_naive
+ )
+ logger.info("Using NaiveTransformer")
+
+ if compile:
+ logger.info("Compiling function...")
+ decode_one_token = torch.compile(
+ decode_one_token,
+ fullgraph=True,
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
+ )
+
+ return model.eval(), decode_one_token
+
+
+@dataclass
+class GenerateResponse:
+ action: Literal["sample", "next"]
+ codes: Optional[torch.Tensor] = None
+ text: Optional[str] = None
+
+
+def generate_long(
+ *,
+ model,
+ device: str | torch.device,
+ decode_one_token: callable,
+ text: str,
+ num_samples: int = 1,
+ max_new_tokens: int = 0,
+ top_p: int = 0.7,
+ repetition_penalty: float = 1.5,
+ temperature: float = 0.7,
+ compile: bool = False,
+ iterative_prompt: bool = True,
+ max_length: int = 2048,
+ chunk_length: int = 150,
+ prompt_text: Optional[str | list[str]] = None,
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
+):
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
+ use_prompt = prompt_text is not None and prompt_tokens is not None
+ if use_prompt and isinstance(prompt_text, str):
+ prompt_text = [prompt_text]
+ prompt_tokens = [prompt_tokens]
+
+ assert use_prompt is False or len(prompt_text) == len(
+ prompt_tokens
+ ), "Prompt text and tokens must have the same length"
+
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ tokenizer = model.tokenizer
+ im_end_id = tokenizer.get_token_id("<|im_end|>")
+
+ encoded = []
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
+ encoded_prompts = [
+ Conversation(
+ messages=[
+ Message(
+ role="system",
+ parts=[TextPart(text="Speak out the provided text.")],
+ cal_loss=False,
+ )
+ ]
+ )
+ .encode_for_inference(
+ tokenizer=tokenizer,
+ num_codebooks=model.config.num_codebooks,
+ )
+ .to(device)
+ ]
+
+ if use_prompt:
+ for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
+ encoded_prompts.append(
+ encode_tokens(
+ tokenizer,
+ string=t,
+ device=device,
+ prompt_tokens=c,
+ num_codebooks=model.config.num_codebooks,
+ )
+ )
+
+ for idx, text in enumerate(texts):
+ encoded.append(
+ encode_tokens(
+ tokenizer,
+ string=text,
+ device=device,
+ num_codebooks=model.config.num_codebooks,
+ )
+ )
+ logger.info(f"Encoded text: {text}")
+
+ # Move temperature, top_p, repetition_penalty to device
+ # This is important so that changing params doesn't trigger recompile
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+ repetition_penalty = torch.tensor(
+ repetition_penalty, device=device, dtype=torch.float
+ )
+
+ for sample_idx in range(num_samples):
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ global_encoded = []
+ seg_idx = 0
+
+ while seg_idx < len(encoded):
+ logger.info(
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+ )
+
+ seg = encoded[seg_idx]
+ global_encoded.append(seg)
+
+ lengths = reversed([seg.size(1) for seg in global_encoded])
+
+ # Pick last 2000 tokens
+ count = 0
+ for i, length in enumerate(lengths):
+ count += length
+ if count + length > max_length - 1024 - sum(
+ t.shape[1] for t in encoded_prompts
+ ):
+ break
+
+ if i != 0 and i % 2 == 0:
+ i -= 1
+
+ # Rotate the list, always make sure first segment is included to avoid drift
+ if i < len(global_encoded) - 2:
+ partial_encoded = global_encoded[:2] + global_encoded[-i:]
+ else:
+ partial_encoded = global_encoded
+
+ if use_prompt:
+ partial_encoded = encoded_prompts + partial_encoded
+
+ cat_encoded = torch.cat(partial_encoded, dim=1)
+ prompt_length = cat_encoded.size(1)
+
+ t0 = time.perf_counter()
+ y = generate(
+ model=model,
+ prompt=cat_encoded,
+ max_new_tokens=max_new_tokens,
+ decode_one_token=decode_one_token,
+ temperature=temperature,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ )
+
+ if sample_idx == 0 and seg_idx == 0 and compile:
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ t = time.perf_counter() - t0
+
+ tokens_generated = y.size(1) - prompt_length
+ tokens_sec = tokens_generated / t
+ logger.info(
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+ )
+ logger.info(
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+ )
+
+ if torch.cuda.is_available():
+ logger.info(
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+ )
+
+ # Put the generated tokens
+ # since there is , we remove last token
+ codes = y[1:, prompt_length + 1 :].clone()
+ assert (codes >= 0).all(), f"Negative code found"
+
+ decoded = y[:, prompt_length:].clone()
+ # But for global encoding, we should keep the token
+
+ global_encoded.append(decoded)
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
+ seg_idx += 1
+
+ # This indicates the end of the current sample
+ yield GenerateResponse(action="next")
+
+
+@dataclass
+class WrappedGenerateResponse:
+ status: Literal["success", "error"]
+ response: Optional[GenerateResponse | Exception] = None
+
+
+@dataclass
+class GenerateRequest:
+ request: dict
+ response_queue: queue.Queue
+
+
+def launch_thread_safe_queue(
+ checkpoint_path,
+ device,
+ precision,
+ compile: bool = False,
+):
+ input_queue = queue.Queue()
+ init_event = threading.Event()
+
+ def worker():
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile
+ )
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1,
+ max_seq_len=model.config.max_seq_len,
+ dtype=next(model.parameters()).dtype,
+ )
+ init_event.set()
+
+ while True:
+ item: GenerateRequest | None = input_queue.get()
+ if item is None:
+ break
+
+ kwargs = item.request
+ response_queue = item.response_queue
+
+ try:
+ for chunk in generate_long(
+ model=model, decode_one_token=decode_one_token, **kwargs
+ ):
+ response_queue.put(
+ WrappedGenerateResponse(status="success", response=chunk)
+ )
+ except Exception as e:
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
+
+ threading.Thread(target=worker, daemon=True).start()
+ init_event.wait()
+
+ return input_queue
+
+
+def launch_thread_safe_queue_agent(
+ checkpoint_path,
+ device,
+ precision,
+ compile: bool = False,
+):
+ input_queue = queue.Queue()
+ init_event = threading.Event()
+
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+ config = BaseModelArgs.from_pretrained(checkpoint_path)
+
+ def worker():
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile, is_agent=True
+ )
+
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1,
+ max_seq_len=model.config.max_seq_len,
+ dtype=next(model.parameters()).dtype,
+ )
+ init_event.set()
+
+ while True:
+ item: GenerateRequest | None = input_queue.get()
+ if item is None:
+ break
+
+ kwargs = item.request
+ response_queue = item.response_queue
+
+ try:
+ for token in generate_agent(
+ model=model,
+ decode_one_token=decode_one_token,
+ **kwargs,
+ ):
+ response_queue.put(token)
+
+ response_queue.put("stop")
+ except Exception as e:
+ import traceback
+
+ logger.exception(f"Error in worker: {traceback.format_exc()}")
+ response_queue.put("error")
+
+ threading.Thread(target=worker, daemon=True).start()
+ init_event.wait()
+
+ return input_queue, tokenizer, config
+
+
+@click.command()
+@click.option(
+ "--text",
+ type=str,
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None, multiple=True)
+@click.option(
+ "--prompt-tokens",
+ type=click.Path(path_type=Path, exists=True),
+ default=None,
+ multiple=True,
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-p", type=float, default=0.7)
+@click.option("--repetition-penalty", type=float, default=1.2)
+@click.option("--temperature", type=float, default=0.7)
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="checkpoints/fish-speech-1.5",
+)
+@click.option("--device", type=str, default="cuda")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--chunk-length", type=int, default=100)
+def main(
+ text: str,
+ prompt_text: Optional[list[str]],
+ prompt_tokens: Optional[list[Path]],
+ num_samples: int,
+ max_new_tokens: int,
+ top_p: int,
+ repetition_penalty: float,
+ temperature: float,
+ checkpoint_path: Path,
+ device: str,
+ compile: bool,
+ seed: int,
+ half: bool,
+ iterative_prompt: bool,
+ chunk_length: int,
+) -> None:
+
+ precision = torch.half if half else torch.bfloat16
+
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+ raise ValueError(
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
+ )
+
+ logger.info("Loading model ...")
+ t0 = time.time()
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile
+ )
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1,
+ max_seq_len=model.config.max_seq_len,
+ dtype=next(model.parameters()).dtype,
+ )
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+ if prompt_tokens is not None:
+ prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
+
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+
+ generator = generate_long(
+ model=model,
+ device=device,
+ decode_one_token=decode_one_token,
+ text=text,
+ num_samples=num_samples,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=compile,
+ iterative_prompt=iterative_prompt,
+ chunk_length=chunk_length,
+ prompt_text=prompt_text,
+ prompt_tokens=prompt_tokens,
+ )
+
+ idx = 0
+ codes = []
+
+ for response in generator:
+ if response.action == "sample":
+ codes.append(response.codes)
+ logger.info(f"Sampled text: {response.text}")
+ elif response.action == "next":
+ if codes:
+ np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
+ logger.info(f"Saved codes to codes_{idx}.npy")
+ logger.info(f"Next sample")
+ codes = []
+ idx += 1
+ else:
+ logger.error(f"Error: {response}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1bd3cbd725c4eccbe78f711d9718dfb278a6aa7
--- /dev/null
+++ b/tools/llama/merge_lora.py
@@ -0,0 +1,95 @@
+import shutil
+from copy import deepcopy
+from pathlib import Path
+
+import click
+import hydra
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+
+from fish_speech.models.text2semantic.llama import BaseTransformer
+from fish_speech.models.text2semantic.lora import get_merged_state_dict
+
+
+@click.command()
+@click.option("--lora-config", type=str, default="r_8_alpha_16")
+@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
+@click.option("--lora-weight", type=str, required=True)
+@click.option("--output", type=str, required=True)
+def merge(lora_config, base_weight, lora_weight, output):
+ output = Path(output)
+ logger.info(
+ f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
+ )
+
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
+ cfg = compose(config_name=lora_config)
+
+ lora_config = instantiate(cfg)
+ logger.info(f"Loaded lora model with config {lora_config}")
+
+ llama_model = BaseTransformer.from_pretrained(
+ path=base_weight,
+ load_weights=True,
+ lora_config=lora_config,
+ )
+ logger.info(f"Loaded llama model")
+
+ llama_state_dict = llama_model.state_dict()
+ llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
+ llama_state_dict_copy = deepcopy(llama_state_dict)
+ lora_state_dict = torch.load(lora_weight, map_location="cpu")
+
+ if "state_dict" in llama_state_dict:
+ llama_state_dict = llama_state_dict["state_dict"]
+
+ if "state_dict" in lora_state_dict:
+ lora_state_dict = lora_state_dict["state_dict"]
+
+ # remove prefix model.
+ if any(k.startswith("model.") for k in llama_state_dict.keys()):
+ llama_state_dict = {
+ k.replace("model.", ""): v
+ for k, v in llama_state_dict.items()
+ if k.startswith("model.")
+ }
+ if any(k.startswith("model.") for k in lora_state_dict.keys()):
+ lora_state_dict = {
+ k.replace("model.", ""): v
+ for k, v in lora_state_dict.items()
+ if k.startswith("model.")
+ }
+
+ logger.info(f"Found {len(llama_state_dict)} keys in llama model")
+ logger.info(f"Found {len(lora_state_dict)} keys in lora model")
+
+ merged_state_dict = llama_state_dict | lora_state_dict
+ llama_model.load_state_dict(merged_state_dict, strict=True)
+ logger.info(f"Merged model loaded")
+
+ # Trigger eval mode to merge lora
+ llama_model.eval()
+ llama_model.save_pretrained(output, drop_lora=True)
+ logger.info(f"Saved merged model to {output}, validating")
+
+ new_state_dict = torch.load(output / "model.pth", map_location="cpu")
+ original_keys = set(llama_state_dict_copy.keys())
+ merged_keys = set(new_state_dict.keys())
+
+ assert original_keys == merged_keys, "Keys should be same"
+
+ for key in original_keys:
+ diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
+ if diff_l1 != 0:
+ break
+ else:
+ logger.error("Merged model is same as the original model")
+ exit(1)
+
+ logger.info("Merged model is different from the original model, check passed")
+
+
+if __name__ == "__main__":
+ merge()
diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e629d944b5d1e262f6c0517480980fcac01dad86
--- /dev/null
+++ b/tools/llama/quantize.py
@@ -0,0 +1,497 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+import datetime
+import shutil
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import time
+from pathlib import Path
+
+import click
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fish_speech.models.text2semantic.llama import find_multiple
+from tools.llama.generate import load_model
+
+##### Quantization Primitives ######
+
+
+def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+ # assumes symmetric quantization
+ # assumes axis == 0
+ # assumes dense memory format
+ # TODO(future): relax ^ as needed
+
+ # default setup for affine quantization of activations
+ eps = torch.finfo(torch.float32).eps
+
+ # get min and max
+ min_val, max_val = torch.aminmax(x, dim=1)
+
+ # calculate scales and zero_points based on min and max
+ # reference: https://fburl.com/code/srbiybme
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+ device = min_val_neg.device
+
+ # reference: https://fburl.com/code/4wll53rk
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
+ # ensure scales is the same dtype as the original tensor
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+ # quantize based on qmin/qmax/scales/zp
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
+ x_div = x / scales.unsqueeze(-1)
+ x_round = torch.round(x_div)
+ x_zp = x_round + zero_points.unsqueeze(-1)
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+ return quant, scales, zero_points
+
+
+def get_group_qparams(w, n_bit=4, groupsize=128):
+ # needed for GPTQ with padding
+ if groupsize > w.shape[-1]:
+ groupsize = w.shape[-1]
+ assert groupsize > 1
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ max_val = to_quant.amax(dim=1, keepdim=True)
+ min_val = to_quant.amin(dim=1, keepdim=True)
+ max_int = 2**n_bit - 1
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
+ zeros = min_val + scales * (2 ** (n_bit - 1))
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
+ torch.bfloat16
+ ).reshape(w.shape[0], -1)
+
+
+def pack_scales_and_zeros(scales, zeros):
+ assert scales.shape == zeros.shape
+ assert scales.dtype == torch.bfloat16
+ assert zeros.dtype == torch.bfloat16
+ return (
+ torch.cat(
+ [
+ scales.reshape(scales.size(0), scales.size(1), 1),
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
+ ],
+ 2,
+ )
+ .transpose(0, 1)
+ .contiguous()
+ )
+
+
+def unpack_scales_and_zeros(scales_and_zeros):
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+ assert scales_and_zeros.dtype == torch.float
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+ assert groupsize > 1
+ # needed for GPTQ single column quantize
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w.shape[-1]
+
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+ min_val = zeros - scales * (2 ** (n_bit - 1))
+ max_int = 2**n_bit - 1
+ min_int = 0
+ w_int32 = (
+ to_quant.sub(min_val)
+ .div(scales)
+ .round()
+ .clamp_(min_int, max_int)
+ .to(torch.int32)
+ .reshape_as(w)
+ )
+
+ return w_int32
+
+
+def group_quantize_tensor(w, n_bit=4, groupsize=128):
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+ return w_int32, scales_and_zeros
+
+
+def group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit=4, groupsize=128
+):
+ assert groupsize > 1
+ # needed for GPTQ single column dequantize
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w_int32.shape[-1]
+ assert w_int32.shape[-1] % groupsize == 0
+ assert w_int32.dim() == 2
+
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+
+ w_dq = (
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+ )
+ return w_dq
+
+
+def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+ return group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit, groupsize
+ )
+
+
+class QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ def create_quantized_state_dict(self) -> "StateDict":
+ pass
+
+ def convert_for_runtime(self) -> "nn.Module":
+ pass
+
+
+##### Weight-only int8 per-channel quantized code ######
+
+
+def replace_linear_weight_only_int8_per_channel(module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt8Linear(child.in_features, child.out_features),
+ )
+ else:
+ replace_linear_weight_only_int8_per_channel(child)
+
+
+class WeightOnlyInt8QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ int8_weight, scales, _ = dynamically_quantize_per_channel(
+ mod.weight.float(), -128, 127, torch.int8
+ )
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_weight_only_int8_per_channel(self.mod)
+ return self.mod
+
+
+class WeightOnlyInt8Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.register_buffer(
+ "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+ )
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+
+
+##### weight only int4 per channel groupwise quantized code ######
+
+
+def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
+ weight_int32, scales_and_zeros = group_quantize_tensor(
+ weight_bf16, n_bit=4, groupsize=groupsize
+ )
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+ weight_int32, inner_k_tiles
+ )
+ return weight_int4pack, scales_and_zeros
+
+
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+ origin_x_size = x.size()
+ x = x.reshape(-1, origin_x_size[-1])
+ c = torch.ops.aten._weight_int4pack_mm(
+ x, weight_int4pack, groupsize, scales_and_zeros
+ )
+ new_shape = origin_x_size[:-1] + (out_features,)
+ c = c.reshape(new_shape)
+ return c
+
+
+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=False,
+ ),
+ )
+ elif padding:
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=True,
+ ),
+ )
+ else:
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+
+
+class WeightOnlyInt4QuantHandler:
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+ self.mod = mod
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+ self.padding = padding
+ assert groupsize in [32, 64, 128, 256]
+ assert inner_k_tiles in [2, 4, 8]
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ assert not mod.bias
+ out_features = mod.out_features
+ in_features = mod.in_features
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+ weight = mod.weight.data
+ if not _check_linear_int4_k(
+ in_features, self.groupsize, self.inner_k_tiles
+ ):
+ if self.padding:
+ import torch.nn.functional as F
+
+ print(
+ f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+ )
+ padded_in_features = find_multiple(in_features, 1024)
+ weight = F.pad(
+ weight, pad=(0, padded_in_features - in_features)
+ )
+ else:
+ print(
+ f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+ )
+ continue
+ (
+ weight_int4pack,
+ scales_and_zeros,
+ ) = prepare_int4_weight_and_scales_and_zeros(
+ weight.to(torch.bfloat16).to("cuda"),
+ self.groupsize,
+ self.inner_k_tiles,
+ )
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+ return self.mod
+
+
+class WeightOnlyInt4Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias=True,
+ device=None,
+ dtype=None,
+ groupsize: int = 128,
+ inner_k_tiles: int = 8,
+ padding: bool = True,
+ ) -> None:
+ super().__init__()
+ self.padding = padding
+ if padding:
+ self.origin_in_features = in_features
+ in_features = find_multiple(in_features, 1024)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ assert not bias, "require bias=False"
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ assert (
+ in_features % (inner_k_tiles * 16) == 0
+ ), "require in_features % (innerKTiles * 16) == 0"
+ self.register_buffer(
+ "weight",
+ torch.empty(
+ (
+ out_features // 8,
+ in_features // (inner_k_tiles * 16),
+ 32,
+ inner_k_tiles // 2,
+ ),
+ dtype=torch.int32,
+ ),
+ )
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty(
+ (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
+ ),
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ input = input.to(torch.bfloat16)
+ if self.padding:
+ import torch.nn.functional as F
+
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+ return linear_forward_int4(
+ input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+ )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+@click.command()
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="checkpoints/fish-speech-1.4",
+)
+@click.option(
+ "--mode", type=str, default="int8", help="type of quantization to perform"
+)
+@click.option(
+ "--groupsize", type=int, default=128, help="Group size for int4 quantization."
+)
+@click.option("--timestamp", type=str, default="None", help="When to do quantization")
+def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
+
+ device = "cpu"
+ precision = torch.bfloat16
+
+ print("Loading model ...")
+ t0 = time.time()
+
+ model, _ = load_model(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=False,
+ )
+ vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+ now = timestamp if timestamp != "None" else generate_folder_name()
+
+ if mode == "int8":
+ print(
+ "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
+ )
+ quant_handler = WeightOnlyInt8QuantHandler(model)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path
+ dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+ if (dst_name / vq_model).exists():
+ (dst_name / vq_model).unlink()
+ quantize_path = dst_name / "model.pth"
+
+ elif mode == "int4":
+ print(
+ "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
+ )
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path
+ dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+ if (dst_name / vq_model).exists():
+ (dst_name / vq_model).unlink()
+ quantize_path = dst_name / "model.pth"
+
+ else:
+ raise ValueError(
+ f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
+ )
+
+ print(f"Writing quantized weights to {quantize_path}")
+ quantize_path.unlink(missing_ok=True) # remove existing file if one already there
+ torch.save(quantized_state_dict, quantize_path)
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
+
+
+if __name__ == "__main__":
+ quantize()
diff --git a/tools/llama/rebuild_tokenizer.py b/tools/llama/rebuild_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea64fa6788833000c8dc41e3d570dd5b250fb14b
--- /dev/null
+++ b/tools/llama/rebuild_tokenizer.py
@@ -0,0 +1,57 @@
+from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+# Initialize a tokenizer
+tokenizer = Tokenizer(models.BPE())
+
+# Customize pre-tokenization and decoding
+tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+tokenizer.decoder = decoders.ByteLevel()
+tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+# Don't train the tokenizer
+trainer = trainers.BpeTrainer(
+ vocab_size=0,
+ min_frequency=2,
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
+ special_tokens=[
+ "<|begin_of_sequence|>",
+ "<|end_of_sequence|>",
+ "<|im_start|>",
+ "<|im_sep|>", # system, user, assistant, etc.
+ "<|im_end|>",
+ "<|semantic|>", # audio features
+ "<|pad|>",
+ ],
+)
+
+# <|im_start|>user<|im_sep|>...<|im_end|>
+# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
+tokenizer.train_from_iterator([], trainer=trainer)
+
+print(len(tokenizer.get_vocab()))
+x = tokenizer.encode(
+ "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
+).ids
+print(x, len(x))
+print(tokenizer.decode(x, skip_special_tokens=True))
+
+
+tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=tokenizer,
+ pad_token="<|pad|>",
+ bos_token="<|begin_of_sequence|>",
+ eos_token="<|end_of_sequence|>",
+)
+
+# Try tokenizing a new sequence
+sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
+encoded = tokenizer(sequence).input_ids
+
+print("Test encoding....")
+print(f"\tSentence: {sequence}")
+print(f"\tEncoded: {encoded}")
+print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
+print(f"\tDecoded: {tokenizer.decode(encoded)}")
+
+tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
diff --git a/tools/run_webui.py b/tools/run_webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a469e466b87612c968187cb07a5a977e35284f2
--- /dev/null
+++ b/tools/run_webui.py
@@ -0,0 +1,104 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import launch_thread_safe_queue
+from tools.schema import ServeTTSRequest
+from tools.vqgan.inference import load_model as load_decoder_model
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.5",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+ parser.add_argument("--theme", type=str, default="light")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ # Check if MPS or CUDA is available
+ if torch.backends.mps.is_available():
+ args.device = "mps"
+ logger.info("mps is available, running on mps.")
+ elif not torch.cuda.is_available():
+ logger.info("CUDA is not available, running on CPU.")
+ args.device = "cpu"
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+
+ logger.info("Loading VQ-GAN model...")
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Create the inference engine
+ inference_engine = TTSInferenceEngine(
+ llama_queue=llama_queue,
+ decoder_model=decoder_model,
+ compile=args.compile,
+ precision=args.precision,
+ )
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference_engine.inference(
+ ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.5,
+ temperature=0.7,
+ format="wav",
+ )
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ # Get the inference function with the immutable arguments
+ inference_fct = get_inference_wrapper(inference_engine)
+
+ app = build_app(inference_fct, args.theme)
+ app.launch(show_api=True, share=True)
diff --git a/tools/schema.py b/tools/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce916005003ed631b9f9695940b3baa7a3223ad
--- /dev/null
+++ b/tools/schema.py
@@ -0,0 +1,170 @@
+import os
+import queue
+from dataclasses import dataclass
+from typing import Annotated, Literal
+
+import torch
+from pydantic import BaseModel, Field, conint, conlist
+from pydantic.functional_validators import SkipValidation
+
+from fish_speech.conversation import Message, TextPart, VQPart
+
+
+class ServeVQPart(BaseModel):
+ type: Literal["vq"] = "vq"
+ codes: SkipValidation[list[list[int]]]
+
+
+class ServeTextPart(BaseModel):
+ type: Literal["text"] = "text"
+ text: str
+
+
+class ServeAudioPart(BaseModel):
+ type: Literal["audio"] = "audio"
+ audio: bytes
+
+
+@dataclass
+class ASRPackRequest:
+ audio: torch.Tensor
+ result_queue: queue.Queue
+ language: str
+
+
+class ServeASRRequest(BaseModel):
+ # The audio should be an uncompressed PCM float16 audio
+ audios: list[bytes]
+ sample_rate: int = 44100
+ language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
+class ServeASRTranscription(BaseModel):
+ text: str
+ duration: float
+ huge_gap: bool
+
+
+class ServeASRSegment(BaseModel):
+ text: str
+ start: float
+ end: float
+
+
+class ServeTimedASRResponse(BaseModel):
+ text: str
+ segments: list[ServeASRSegment]
+ duration: float
+
+
+class ServeASRResponse(BaseModel):
+ transcriptions: list[ServeASRTranscription]
+
+
+class ServeMessage(BaseModel):
+ role: Literal["system", "assistant", "user"]
+ parts: list[ServeVQPart | ServeTextPart]
+
+ def to_conversation_message(self):
+ new_message = Message(role=self.role, parts=[])
+ if self.role == "assistant":
+ new_message.modality = "voice"
+
+ for part in self.parts:
+ if isinstance(part, ServeTextPart):
+ new_message.parts.append(TextPart(text=part.text))
+ elif isinstance(part, ServeVQPart):
+ new_message.parts.append(
+ VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
+ )
+ else:
+ raise ValueError(f"Unsupported part type: {part}")
+
+ return new_message
+
+
+class ServeChatRequest(BaseModel):
+ messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
+ max_new_tokens: int = 1024
+ top_p: float = 0.7
+ repetition_penalty: float = 1.2
+ temperature: float = 0.7
+ streaming: bool = False
+ num_samples: int = 1
+ early_stop_threshold: float = 1.0
+
+
+class ServeVQGANEncodeRequest(BaseModel):
+ # The audio here should be in wav, mp3, etc
+ audios: list[bytes]
+
+
+class ServeVQGANEncodeResponse(BaseModel):
+ tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeRequest(BaseModel):
+ tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeResponse(BaseModel):
+ # The audio here should be in PCM float16 format
+ audios: list[bytes]
+
+
+class ServeForwardMessage(BaseModel):
+ role: str
+ content: str
+
+
+class ServeResponse(BaseModel):
+ messages: list[ServeMessage]
+ finish_reason: Literal["stop", "error"] | None = None
+ stats: dict[str, int | float | str] = {}
+
+
+class ServeStreamDelta(BaseModel):
+ role: Literal["system", "assistant", "user"] | None = None
+ part: ServeVQPart | ServeTextPart | None = None
+
+
+class ServeStreamResponse(BaseModel):
+ sample_id: int = 0
+ delta: ServeStreamDelta | None = None
+ finish_reason: Literal["stop", "error"] | None = None
+ stats: dict[str, int | float | str] | None = None
+
+
+class ServeReferenceAudio(BaseModel):
+ audio: bytes
+ text: str
+
+ def __repr__(self) -> str:
+ return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
+class ServeTTSRequest(BaseModel):
+ text: str
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+ # Audio format
+ format: Literal["wav", "pcm", "mp3"] = "wav"
+ # References audios for in-context learning
+ references: list[ServeReferenceAudio] = []
+ # Reference id
+ # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+ reference_id: str | None = None
+ seed: int | None = None
+ use_memory_cache: Literal["on", "off"] = "off"
+ # Normalize text for en & zh, this increase stability for numbers
+ normalize: bool = True
+ # not usually used below
+ streaming: bool = False
+ max_new_tokens: int = 1024
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+
+ class Config:
+ # Allow arbitrary types for pytorch related types
+ arbitrary_types_allowed = True
diff --git a/tools/sensevoice/README.md b/tools/sensevoice/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a2078aa2d96dfafb445384316f2041d9e819e63
--- /dev/null
+++ b/tools/sensevoice/README.md
@@ -0,0 +1,59 @@
+# FunASR Command Line Interface
+
+This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch <= 2.3.1
+- ffmpeg, pydub, audio-separator[gpu].
+
+## Installation
+
+Install the required packages:
+
+```bash
+pip install -e .[stable]
+```
+
+Make sure you have `ffmpeg` installed and available in your `PATH`.
+
+## Usage
+
+### Basic Usage
+
+To run the tool with default settings:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir --save-dir
+```
+
+## Options
+
+| Option | Description |
+| :-----------------------: | :---------------------------------------------------------------------------: |
+| --audio-dir | Directory containing audio or video files. |
+| --save-dir | Directory to save processed audio files. |
+| --device | Device to use for processing. Options: cuda (default) or cpu. |
+| --language | Language of the transcription. Default is auto. |
+| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
+| --punc | Enable punctuation prediction. |
+| --denoise | Enable noise reduction (vocal separation). |
+
+## Example
+
+To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
+```
+
+## Additional Notes
+
+- The tool supports `both audio and video files`. Videos will be converted to audio automatically.
+- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
+- The script will automatically create necessary directories in the `--save-dir`.
+
+## Troubleshooting
+
+If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2e186617fe889500d01d95eccdafc5c0248b84
--- /dev/null
+++ b/tools/sensevoice/auto_model.py
@@ -0,0 +1,573 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+ from funasr.models.campplus.cluster_backend import ClusterBackend
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+ pass
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+ """ """
+ data_list = []
+ key_list = []
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
+
+ chars = string.ascii_letters + string.digits
+ if isinstance(data_in, str):
+ if data_in.startswith("http://") or data_in.startswith("https://"): # url
+ data_in = download_from_url(data_in)
+
+ if isinstance(data_in, str) and os.path.exists(
+ data_in
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+ _, file_extension = os.path.splitext(data_in)
+ file_extension = file_extension.lower()
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
+ with open(data_in, encoding="utf-8") as fin:
+ for line in fin:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ if data_in.endswith(
+ ".jsonl"
+ ): # file.jsonl: json.dumps({"source": data})
+ lines = json.loads(line.strip())
+ data = lines["source"]
+ key = data["key"] if "key" in data else key
+ else: # filelist, wav.scp, text.txt: id \t data or data
+ lines = line.strip().split(maxsplit=1)
+ data = lines[1] if len(lines) > 1 else lines[0]
+ key = lines[0] if len(lines) > 1 else key
+
+ data_list.append(data)
+ key_list.append(key)
+ else:
+ if key is None:
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ key = misc.extract_filename_without_extension(data_in)
+ data_list = [data_in]
+ key_list = [key]
+ elif isinstance(data_in, (list, tuple)):
+ if data_type is not None and isinstance(
+ data_type, (list, tuple)
+ ): # mutiple inputs
+ data_list_tmp = []
+ for data_in_i, data_type_i in zip(data_in, data_type):
+ key_list, data_list_i = prepare_data_iterator(
+ data_in=data_in_i, data_type=data_type_i
+ )
+ data_list_tmp.append(data_list_i)
+ data_list = []
+ for item in zip(*data_list_tmp):
+ data_list.append(item)
+ else:
+ # [audio sample point, fbank, text]
+ data_list = data_in
+ key_list = []
+ for data_i in data_in:
+ if isinstance(data_i, str) and os.path.exists(data_i):
+ key = misc.extract_filename_without_extension(data_i)
+ else:
+ if key is None:
+ key = "rand_key_" + "".join(
+ random.choice(chars) for _ in range(13)
+ )
+ key_list.append(key)
+
+ else: # raw text; audio sample point, fbank; bytes
+ if isinstance(data_in, bytes): # audio bytes
+ data_in = load_bytes(data_in)
+ if key is None:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ data_list = [data_in]
+ key_list = [key]
+
+ return key_list, data_list
+
+
+class AutoModel:
+
+ def __init__(self, **kwargs):
+
+ try:
+ from funasr.utils.version_checker import check_for_update
+
+ print(
+ "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
+ )
+ check_for_update(disable=kwargs.get("disable_update", False))
+ except:
+ pass
+
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+ logging.basicConfig(level=log_level)
+
+ model, kwargs = self.build_model(**kwargs)
+
+ # if vad_model is not None, build vad model else None
+ vad_model = kwargs.get("vad_model", None)
+ vad_kwargs = (
+ {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
+ )
+ if vad_model is not None:
+ logging.info("Building VAD model.")
+ vad_kwargs["model"] = vad_model
+ vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
+ vad_kwargs["device"] = kwargs["device"]
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+ # if punc_model is not None, build punc model else None
+ punc_model = kwargs.get("punc_model", None)
+ punc_kwargs = (
+ {}
+ if kwargs.get("punc_kwargs", {}) is None
+ else kwargs.get("punc_kwargs", {})
+ )
+ if punc_model is not None:
+ logging.info("Building punc model.")
+ punc_kwargs["model"] = punc_model
+ punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
+ punc_kwargs["device"] = kwargs["device"]
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+ # if spk_model is not None, build spk model else None
+ spk_model = kwargs.get("spk_model", None)
+ spk_kwargs = (
+ {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
+ )
+ if spk_model is not None:
+ logging.info("Building SPK model.")
+ spk_kwargs["model"] = spk_model
+ spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
+ spk_kwargs["device"] = kwargs["device"]
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+ self.cb_model = ClusterBackend().to(kwargs["device"])
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+ logging.error(
+ "spk_mode should be one of default, vad_segment and punc_segment."
+ )
+ self.spk_mode = spk_mode
+
+ self.kwargs = kwargs
+ self.model = model
+ self.vad_model = vad_model
+ self.vad_kwargs = vad_kwargs
+ self.punc_model = punc_model
+ self.punc_kwargs = punc_kwargs
+ self.spk_model = spk_model
+ self.spk_kwargs = spk_kwargs
+ self.model_path = kwargs.get("model_path")
+
+ @staticmethod
+ def build_model(**kwargs):
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info(
+ "download models from model hub: {}".format(kwargs.get("hub", "ms"))
+ )
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ torch.set_num_threads(kwargs.get("ncpu", 4))
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+ tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
+ kwargs["token_list"] = (
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ )
+ kwargs["token_list"] = (
+ tokenizer.get_vocab()
+ if hasattr(tokenizer, "get_vocab")
+ else kwargs["token_list"]
+ )
+ vocab_size = (
+ len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+ )
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
+ else:
+ vocab_size = -1
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend
+ frontend = kwargs.get("frontend", None)
+ kwargs["input_size"] = None
+ if frontend is not None:
+ frontend_class = tables.frontend_classes.get(frontend)
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
+ kwargs["input_size"] = (
+ frontend.output_size() if hasattr(frontend, "output_size") else None
+ )
+ kwargs["frontend"] = frontend
+ # build model
+ model_class = tables.model_classes.get(kwargs["model"])
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
+ model_conf = {}
+ deep_update(model_conf, kwargs.get("model_conf", {}))
+ deep_update(model_conf, kwargs)
+ model = model_class(**model_conf, vocab_size=vocab_size)
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ if os.path.exists(init_param):
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ path=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ scope_map=kwargs.get("scope_map", []),
+ excludes=kwargs.get("excludes", None),
+ )
+ else:
+ print(f"error, init_param does not exist!: {init_param}")
+
+ # fp16
+ if kwargs.get("fp16", False):
+ model.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ model.to(torch.bfloat16)
+ model.to(device)
+
+ if not kwargs.get("disable_log", True):
+ tables.print()
+
+ return model, kwargs
+
+ def __call__(self, *args, **cfg):
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ res = self.model(*args, kwargs)
+ return res
+
+ def generate(self, input, input_len=None, **cfg):
+ if self.vad_model is None:
+ return self.inference(input, input_len=input_len, **cfg)
+
+ else:
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+ def inference(
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+ ):
+ kwargs = self.kwargs if kwargs is None else kwargs
+ if "cache" in kwargs:
+ kwargs.pop("cache")
+ deep_update(kwargs, cfg)
+ model = self.model if model is None else model
+ model.eval()
+
+ batch_size = kwargs.get("batch_size", 1)
+ # if kwargs.get("device", "cpu") == "cpu":
+ # batch_size = 1
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+ )
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ disable_pbar = self.kwargs.get("disable_pbar", False)
+ pbar = (
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ if not disable_pbar
+ else None
+ )
+ time_speech_total = 0.0
+ time_escape_total = 0.0
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+
+ if (end_idx - beg_idx) == 1 and kwargs.get(
+ "data_type", None
+ ) == "fbank": # fbank
+ batch["data_in"] = data_batch[0]
+ batch["data_lengths"] = input_len
+
+ time1 = time.perf_counter()
+ with torch.no_grad():
+ res = model.inference(**batch, **kwargs)
+ if isinstance(res, (list, tuple)):
+ results = res[0] if len(res) > 0 else [{"text": ""}]
+ meta_data = res[1] if len(res) > 1 else {}
+ time2 = time.perf_counter()
+
+ asr_result_list.extend(results)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ time_escape = time2 - time1
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+ speed_stats["forward"] = f"{time_escape:0.3f}"
+ speed_stats["batch_size"] = f"{len(results)}"
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+ description = f"{speed_stats}, "
+ if pbar:
+ pbar.update(end_idx - beg_idx)
+ pbar.set_description(description)
+ time_speech_total += batch_data_time
+ time_escape_total += time_escape
+
+ if pbar:
+ # pbar.update(1)
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+ torch.cuda.empty_cache()
+ return asr_result_list
+
+ def vad(self, input, input_len=None, **cfg):
+ kwargs = self.kwargs
+ # step.1: compute the vad model
+ deep_update(self.vad_kwargs, cfg)
+ beg_vad = time.time()
+ res = self.inference(
+ input,
+ input_len=input_len,
+ model=self.vad_model,
+ kwargs=self.vad_kwargs,
+ **cfg,
+ )
+ end_vad = time.time()
+ # FIX(gcf): concat the vad clips for sense vocie model for better aed
+ if cfg.get("merge_vad", False):
+ for i in range(len(res)):
+ res[i]["value"] = merge_vad(
+ res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
+ )
+ elapsed = end_vad - beg_vad
+ return elapsed, res
+
+ def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
+
+ kwargs = self.kwargs
+
+ # step.2 compute asr model
+ model = self.model
+ deep_update(kwargs, cfg)
+ batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+ kwargs["batch_size"] = batch_size
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
+ )
+ results_ret_list = []
+ time_speech_total_all_samples = 1e-6
+
+ beg_total = time.time()
+ pbar_total = (
+ tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
+ if not kwargs.get("disable_pbar", False)
+ else None
+ )
+
+ for i in range(len(vad_res)):
+ key = vad_res[i]["key"]
+ vadsegments = vad_res[i]["value"]
+ input_i = data_list[i]
+ fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
+ speech = load_audio_text_image_video(
+ input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
+ )
+ speech_lengths = len(speech)
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+
+ if not len(sorted_data):
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
+ logging.info("decoding, utt: {}, empty speech".format(key))
+ continue
+
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+ batch_size = max(
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+ )
+
+ if kwargs["device"] == "cpu":
+ batch_size = 0
+
+ beg_idx = 0
+ beg_asr_total = time.time()
+ time_speech_total_per_sample = speech_lengths / 16000
+ time_speech_total_all_samples += time_speech_total_per_sample
+
+ # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
+
+ all_segments = []
+ max_len_in_batch = 0
+ end_idx = 1
+
+ for j, _ in enumerate(range(0, n)):
+ # pbar_sample.update(1)
+ sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+ potential_batch_length = max(max_len_in_batch, sample_length) * (
+ j + 1 - beg_idx
+ )
+ # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+ if (
+ j < n - 1
+ and sample_length < batch_size_threshold_ms
+ and potential_batch_length < batch_size
+ ):
+ max_len_in_batch = max(max_len_in_batch, sample_length)
+ end_idx += 1
+ continue
+
+ speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
+ )
+ results = self.inference(
+ speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
+ )
+
+ for _b in range(len(speech_j)):
+ results[_b]["interval"] = intervals[_b]
+
+ if self.spk_model is not None:
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+ for _b in range(len(speech_j)):
+ vad_segments = [
+ [
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
+ np.array(speech_j[_b]),
+ ]
+ ]
+ segments = sv_chunk(vad_segments)
+ all_segments.extend(segments)
+ speech_b = [i[2] for i in segments]
+ spk_res = self.inference(
+ speech_b,
+ input_len=None,
+ model=self.spk_model,
+ kwargs=kwargs,
+ **cfg,
+ )
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
+
+ beg_idx = end_idx
+ end_idx += 1
+ max_len_in_batch = sample_length
+ if len(results) < 1:
+ continue
+ results_sorted.extend(results)
+
+ # end_asr_total = time.time()
+ # time_escape_total_per_sample = end_asr_total - beg_asr_total
+ # pbar_sample.update(1)
+ # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+ # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ cur = results_sorted[j]
+ pattern = r"<\|([^|]+)\|>"
+ emotion_string = re.findall(pattern, cur["text"])
+ cur["text"] = re.sub(pattern, "", cur["text"])
+ cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
+ if self.punc_model is not None and len(cur["text"].strip()) > 0:
+ deep_update(self.punc_kwargs, cfg)
+ punc_res = self.inference(
+ cur["text"],
+ model=self.punc_model,
+ kwargs=self.punc_kwargs,
+ **cfg,
+ )
+ cur["text"] = punc_res[0]["text"]
+
+ restored_data[index] = cur
+
+ end_asr_total = time.time()
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
+ if pbar_total:
+ pbar_total.update(1)
+ pbar_total.set_description(
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
+ )
+
+ # end_total = time.time()
+ # time_escape_total_all_samples = end_total - beg_total
+ # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+ # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
+ # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
+ return restored_data
+
+ def export(self, input=None, **cfg):
+ """
+
+ :param input:
+ :param type:
+ :param quantize:
+ :param fallback_num:
+ :param calib_num:
+ :param opset_version:
+ :param cfg:
+ :return:
+ """
+
+ device = cfg.get("device", "cpu")
+ model = self.model.to(device=device)
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ kwargs["device"] = device
+ del kwargs["model"]
+ model.eval()
+
+ type = kwargs.get("type", "onnx")
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=None, data_type=kwargs.get("data_type", None), key=None
+ )
+
+ with torch.no_grad():
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
+
+ return export_dir
diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..6789316d5186db69c021758094649553c3638f66
--- /dev/null
+++ b/tools/sensevoice/fun_asr.py
@@ -0,0 +1,332 @@
+import gc
+import os
+import re
+
+from audio_separator.separator import Separator
+
+os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
+os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
+import json
+import subprocess
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+from pydub import AudioSegment
+from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from tools.sensevoice.auto_model import AutoModel
+
+
+def uvr5_cli(
+ audio_dir: Path,
+ output_folder: Path,
+ audio_files: list[Path] | None = None,
+ output_format: str = "flac",
+ model: str = "BS-Roformer-Viperx-1297.ckpt",
+):
+ # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
+ sepr = Separator(
+ model_file_dir=os.environ["UVR5_CACHE"],
+ output_dir=output_folder,
+ output_format=output_format,
+ )
+ dictmodel = {
+ "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
+ "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
+ "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
+ "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
+ }
+ roformer_model = dictmodel[model]
+ sepr.load_model(roformer_model)
+ if audio_files is None:
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+ total_files = len(audio_files)
+
+ print(f"{total_files} audio files found")
+
+ res = []
+ for audio in tqdm(audio_files, desc="Denoising: "):
+ file_path = str(audio_dir / audio)
+ sep_out = sepr.separate(file_path)
+ if isinstance(sep_out, str):
+ res.append(sep_out)
+ elif isinstance(sep_out, list):
+ res.extend(sep_out)
+ del sepr
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return res, roformer_model
+
+
+def get_sample_rate(media_path: Path):
+ result = subprocess.run(
+ [
+ "ffprobe",
+ "-v",
+ "quiet",
+ "-print_format",
+ "json",
+ "-show_streams",
+ str(media_path),
+ ],
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+ media_info = json.loads(result.stdout)
+ for stream in media_info.get("streams", []):
+ if stream.get("codec_type") == "audio":
+ return stream.get("sample_rate")
+ return "44100" # Default sample rate if not found
+
+
+def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
+ sr = get_sample_rate(src_path)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ if src_path.resolve() == out_path.resolve():
+ output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
+ else:
+ output = str(out_path)
+ subprocess.run(
+ [
+ "ffmpeg",
+ "-loglevel",
+ "error",
+ "-i",
+ str(src_path),
+ "-acodec",
+ "pcm_s16le" if out_fmt == "wav" else "flac",
+ "-ar",
+ sr,
+ "-ac",
+ "1",
+ "-y",
+ output,
+ ],
+ check=True,
+ )
+ return out_path
+
+
+def convert_video_to_audio(video_path: Path, audio_dir: Path):
+ cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
+ vocals = [
+ p
+ for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
+ if p.suffix in AUDIO_EXTENSIONS
+ ]
+ if len(vocals) > 0:
+ return vocals[0]
+ audio_path = cur_dir / f"{video_path.stem}.wav"
+ convert_to_mono(video_path, audio_path)
+ return audio_path
+
+
+@click.command()
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option(
+ "--max_single_segment_time",
+ default=20000,
+ type=int,
+ help="Maximum of Output single audio duration(ms)",
+)
+@click.option("--fsmn-vad/--silero-vad", default=False)
+@click.option("--punc/--no-punc", default=False)
+@click.option("--denoise/--no-denoise", default=False)
+@click.option("--save_emo/--no_save_emo", default=False)
+def main(
+ audio_dir: str,
+ save_dir: str,
+ device: str,
+ language: str,
+ max_single_segment_time: int,
+ fsmn_vad: bool,
+ punc: bool,
+ denoise: bool,
+ save_emo: bool,
+):
+
+ audios_path = Path(audio_dir)
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ video_files = list_files(
+ path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
+ )
+ v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
+
+ if denoise:
+ VOCAL = "_(Vocals)"
+ original_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
+ ]
+
+ _, cur_model = uvr5_cli(
+ audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
+ )
+ need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
+ need_remove.extend(original_files)
+ for _ in need_remove:
+ _.unlink()
+ vocal_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
+ ]
+ for f in vocal_files:
+ fn, ext = f.stem, f.suffix
+
+ v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
+ if v_pos != -1:
+ new_fn = fn[: v_pos + len(VOCAL)]
+ new_f = f.with_name(new_fn + ext)
+ f = f.rename(new_f)
+ convert_to_mono(f, f, "flac")
+ f.unlink()
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ logger.info("Loading / Downloading Funasr model...")
+
+ model_dir = "iic/SenseVoiceSmall"
+
+ vad_model = "fsmn-vad" if fsmn_vad else None
+ vad_kwargs = {"max_single_segment_time": max_single_segment_time}
+ punc_model = "ct-punc" if punc else None
+
+ manager = AutoModel(
+ model=model_dir,
+ trust_remote_code=False,
+ vad_model=vad_model,
+ vad_kwargs=vad_kwargs,
+ punc_model=punc_model,
+ device=device,
+ )
+
+ if not fsmn_vad and vad_model is None:
+ vad_model = load_silero_vad()
+
+ logger.info("Model loaded.")
+
+ pattern = re.compile(r"_\d{3}\.")
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+
+ if pattern.search(file_path.name):
+ # logger.info(f"Skipping {file_path} as it has already been processed.")
+ continue
+
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
+ cfg = dict(
+ cache={},
+ language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ batch_size_s=60,
+ )
+
+ if fsmn_vad:
+ elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
+ else:
+ wav = read_audio(
+ str(file_path)
+ ) # backend (sox, soundfile, or ffmpeg) required!
+ audio_key = file_path.stem
+ audio_val = []
+ speech_timestamps = get_speech_timestamps(
+ wav,
+ vad_model,
+ max_speech_duration_s=max_single_segment_time // 1000,
+ return_seconds=True,
+ )
+
+ audio_val = [
+ [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
+ for timestamp in speech_timestamps
+ ]
+ vad_res = []
+ vad_res.append(dict(key=audio_key, value=audio_val))
+
+ res = manager.inference_with_vadres(
+ input=str(file_path), vad_res=vad_res, **cfg
+ )
+
+ for i, info in enumerate(res):
+ [start_ms, end_ms] = info["interval"]
+ text = info["text"]
+ emo = info["emo"]
+ sliced_audio = audio[start_ms:end_ms]
+ audio_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
+ )
+ sliced_audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}: {text}")
+
+ transcript_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
+ )
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(text)
+
+ if save_emo:
+ emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
+ with open(
+ emo_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(emo)
+
+ if audios_path.resolve() == save_path.resolve():
+ file_path.unlink()
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+ # Load the audio file
+ audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
+ model_dir = "iic/SenseVoiceSmall"
+ m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
+ m.eval()
+
+ res = m.inference(
+ data_in=f"{kwargs['model_path']}/example/zh.mp3",
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ ban_emo_unk=False,
+ **kwargs,
+ )
+
+ print(res)
+ text = rich_transcription_postprocess(res[0][0]["text"])
+ print(text)
diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bef75ed8c2841701fff44f7130e91ef8dfdf8cc
--- /dev/null
+++ b/tools/sensevoice/vad_utils.py
@@ -0,0 +1,61 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
+ speech_i = speech[0, bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+ return feats_pad, speech_lengths_pad
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ intervals = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ intervals.append([bed_idx // 16, end_idx // 16])
+
+ return speech_list, speech_lengths_list, intervals
+
+
+def merge_vad(vad_result, max_length=15000, min_length=0):
+ new_result = []
+ if len(vad_result) <= 1:
+ return vad_result
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
+ time_step = sorted(list(set(time_step)))
+ if len(time_step) == 0:
+ return []
+ bg = 0
+ for i in range(len(time_step) - 1):
+ time = time_step[i]
+ if time_step[i + 1] - bg < max_length:
+ continue
+ if time - bg > min_length:
+ new_result.append([bg, time])
+ # if time - bg < max_length * 1.5:
+ # new_result.append([bg, time])
+ # else:
+ # split_num = int(time - bg) // max_length + 1
+ # spl_l = int(time - bg) // split_num
+ # for j in range(split_num):
+ # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
+ bg = time
+ new_result.append([bg, time_step[-1]])
+ return new_result
diff --git a/tools/server/__pycache__/api_utils.cpython-310.pyc b/tools/server/__pycache__/api_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..114a18b10a43d715049ff1de9987ee802cbfbe5a
Binary files /dev/null and b/tools/server/__pycache__/api_utils.cpython-310.pyc differ
diff --git a/tools/server/__pycache__/exception_handler.cpython-310.pyc b/tools/server/__pycache__/exception_handler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0113e2c69432818c79228f22a8915dd2fdfbf2eb
Binary files /dev/null and b/tools/server/__pycache__/exception_handler.cpython-310.pyc differ
diff --git a/tools/server/__pycache__/inference.cpython-310.pyc b/tools/server/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e0845f16e5aca47858244c34102c5df46aaed19
Binary files /dev/null and b/tools/server/__pycache__/inference.cpython-310.pyc differ
diff --git a/tools/server/__pycache__/model_manager.cpython-310.pyc b/tools/server/__pycache__/model_manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a88dd61676610d2f793899f3dab199c9396681a0
Binary files /dev/null and b/tools/server/__pycache__/model_manager.cpython-310.pyc differ
diff --git a/tools/server/__pycache__/model_utils.cpython-310.pyc b/tools/server/__pycache__/model_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..936fe18898f7b92bde8227d8cee58d5eae94c97b
Binary files /dev/null and b/tools/server/__pycache__/model_utils.cpython-310.pyc differ
diff --git a/tools/server/__pycache__/views.cpython-310.pyc b/tools/server/__pycache__/views.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..991be74edbf77e5a06a21d170a12f7adfc13774e
Binary files /dev/null and b/tools/server/__pycache__/views.cpython-310.pyc differ
diff --git a/tools/server/agent/__init__.py b/tools/server/agent/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b0a2ecd22dafaf696c09218b5ed73fd4ef56b3
--- /dev/null
+++ b/tools/server/agent/__init__.py
@@ -0,0 +1,57 @@
+import struct
+from functools import partial
+
+import ormsgpack
+
+from tools.server.agent.generate import generate_responses
+from tools.server.agent.pre_generation_utils import prepare_messages
+
+
+def execute_request(input_queue, tokenizer, config, request, device):
+ """
+ This function prepares the conversation, encodes the request,
+ sends the generation request, and handles decoding/streaming.
+ It returns a response generator (ServeResponse or ServeStreamResponse).
+ """
+ prompt, im_end_id = prepare_messages(request, tokenizer, config)
+ yield from generate_responses(
+ input_queue, tokenizer, config, request, prompt, im_end_id, device
+ )
+
+
+def response_generator(req, llama_queue, tokenizer, config, device):
+ """
+ Non-streaming response wrapper for the chat endpoint.
+ Only returns the final result.
+ """
+ generator = execute_request(llama_queue, tokenizer, config, req, device)
+ return next(generator)
+
+
+async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
+ """
+ Streaming response wrapper for the chat endpoint.
+ Returns the response in chunks.
+ """
+ generator = execute_request(llama_queue, tokenizer, config, req, device)
+ for i in generator:
+ if json_mode:
+ body = i.model_dump_json().encode("utf-8")
+ yield b"data: " + body + b"\n\n"
+ else:
+ body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+ yield struct.pack("I", len(body)) + body
+
+
+def get_response_generator(
+ llama_queue, tokenizer, config, req, device, json_mode
+) -> partial:
+ """
+ Get the correct response generator based on the request.
+ """
+ if not req.streaming:
+ return partial(response_generator, req, llama_queue, tokenizer, config, device)
+ else:
+ return partial(
+ streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
+ )
diff --git a/tools/server/agent/__pycache__/__init__.cpython-310.pyc b/tools/server/agent/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6111f9438e40d6b4fe6b6ab5f071f84f2b62543
Binary files /dev/null and b/tools/server/agent/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tools/server/agent/__pycache__/generate.cpython-310.pyc b/tools/server/agent/__pycache__/generate.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53e2701a12d25bd13372ffa5d981fcda6adf9086
Binary files /dev/null and b/tools/server/agent/__pycache__/generate.cpython-310.pyc differ
diff --git a/tools/server/agent/__pycache__/generation_utils.cpython-310.pyc b/tools/server/agent/__pycache__/generation_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e5651e478ee12ec4048bcfaf6c3780291aaba06
Binary files /dev/null and b/tools/server/agent/__pycache__/generation_utils.cpython-310.pyc differ
diff --git a/tools/server/agent/__pycache__/pre_generation_utils.cpython-310.pyc b/tools/server/agent/__pycache__/pre_generation_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6823f323c2569e4f7218352b624388cf133b14f
Binary files /dev/null and b/tools/server/agent/__pycache__/pre_generation_utils.cpython-310.pyc differ
diff --git a/tools/server/agent/generate.py b/tools/server/agent/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef4ae9a7b49034b58d4cb9db055f3d05f92b07b9
--- /dev/null
+++ b/tools/server/agent/generate.py
@@ -0,0 +1,119 @@
+import time
+
+from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
+from tools.server.agent.generation_utils import (
+ initialize_decode_buffers,
+ process_response_tokens,
+ send_reset_buffer,
+)
+from tools.server.agent.pre_generation_utils import (
+ create_generation_request,
+ send_generation_request,
+)
+
+
+def generate_responses(
+ input_queue, tokenizer, config, request, prompt, im_end_id, device
+):
+ """
+ Main generation function that handles the conversation, encodes the request,
+ sends the generation request, and handles decoding/streaming.
+ It returns a response generator (ServeResponse or ServeStreamResponse).
+ """
+ stats = {}
+ start = time.time()
+ stats["start_time"] = start
+ stats["tokens_count"] = 0
+
+ # Prepare and send the generation request
+ req = create_generation_request(prompt, request, im_end_id, device)
+ response_queue = send_generation_request(input_queue, req)
+ decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
+
+ while True:
+ response = response_queue.get()
+
+ # Handle abnormal finish or error
+ if response in ["stop", "error"]:
+ finish_reason = response
+ break
+
+ # Process the response tokens
+ is_first_token = stats["tokens_count"] == 0
+ responses = process_response_tokens(
+ response,
+ tokenizer,
+ config,
+ request,
+ decode_buffer,
+ parts,
+ finished,
+ im_end_id,
+ stats,
+ start,
+ is_first_token,
+ )
+
+ # Yield the responses if streaming
+ if request.streaming and responses:
+ for r in responses:
+ yield r
+
+ stats["tokens_count"] += 1
+
+ # Check if all samples are finished
+ if all(finished):
+ finish_reason = "stop"
+ break
+
+ # Finalize the response
+ final_responses = finalize_response(
+ request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
+ )
+ for fr in final_responses:
+ yield fr
+
+
+def finalize_response(
+ request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
+):
+ """
+ Finalize the response by sending the remaining text buffers.
+ """
+ responses = []
+
+ # Send the remaining text buffers
+ for sample_id in range(request.num_samples):
+ responses.extend(
+ send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+ )
+
+ # Calculate the final stats
+ stats["total_time"] = (time.time() - stats["start_time"]) * 1000
+ stats["total_tokens"] = stats["tokens_count"]
+
+ # If streaming, send the final chunks for each sample
+ if request.streaming:
+ for sample_id in range(request.num_samples):
+ if finished[sample_id]:
+ continue
+ responses.append(
+ ServeStreamResponse(
+ finish_reason=finish_reason, stats=stats, sample_id=sample_id
+ )
+ )
+ else:
+ # If not streaming, send the full messages for each sample
+ full_messages = [
+ ServeMessage(role="assistant", parts=parts[i])
+ for i in range(request.num_samples)
+ ]
+ responses.append(
+ ServeResponse(
+ messages=full_messages,
+ finish_reason=finish_reason,
+ stats=stats,
+ )
+ )
+
+ return responses
diff --git a/tools/server/agent/generation_utils.py b/tools/server/agent/generation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2dd4e4dbbb10630ce09e8c44574c7c43aab26d
--- /dev/null
+++ b/tools/server/agent/generation_utils.py
@@ -0,0 +1,122 @@
+import time
+
+from tools.schema import (
+ ServeStreamDelta,
+ ServeStreamResponse,
+ ServeTextPart,
+ ServeVQPart,
+)
+
+
+def initialize_decode_buffers(num_samples):
+ """Initialise the decode buffers for each sample."""
+ decode_buffer = [[] for _ in range(num_samples)]
+ parts = [[] for _ in range(num_samples)]
+ finished = [False for _ in range(num_samples)]
+ return decode_buffer, parts, finished
+
+
+def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
+ """Send the remaining text buffer for a sample."""
+ if len(decode_buffer[sample_id]) == 0:
+ return []
+
+ decoded = tokenizer.decode(decode_buffer[sample_id])
+ part = ServeTextPart(text=decoded)
+
+ responses = []
+ if request.streaming:
+ responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
+ else:
+ parts[sample_id].append(part)
+
+ decode_buffer[sample_id] = []
+ return responses
+
+
+def handle_semantic_tokens(tokens, config, sample_id, parts, request):
+ """Handle the semantic tokens returned by the model."""
+ responses = []
+ _tokens = tokens[1:].clone()
+
+ if not config.share_codebook_embeddings:
+ for i in range(len(_tokens)):
+ _tokens[i] -= config.codebook_size * i
+
+ # If streaming, send the VQ parts directly
+ if request.streaming:
+ responses.append(
+ ServeStreamResponse(
+ sample_id=sample_id,
+ delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+ )
+ )
+ else:
+ # If not streaming, accumulate the VQ parts
+ if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
+ parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+ else:
+ # Accumulate the codes
+ for codebook_id, value in enumerate(_tokens):
+ parts[sample_id][-1].codes[codebook_id].append(value.item())
+
+ return responses
+
+
+def process_response_tokens(
+ response,
+ tokenizer,
+ config,
+ request,
+ decode_buffer,
+ parts,
+ finished,
+ im_end_id,
+ stats,
+ start,
+ is_first_token,
+):
+ """Process the response tokens returned by the model."""
+ responses = []
+ for sample_id, tokens in enumerate(response):
+ if finished[sample_id]:
+ continue
+
+ # End of the conversation
+ if tokens[0] == im_end_id:
+ finished[sample_id] = True
+ # Send the remaining text buffer
+ responses.extend(
+ send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+ )
+ if request.streaming:
+ responses.append(
+ ServeStreamResponse(
+ sample_id=sample_id,
+ finish_reason="stop",
+ stats=stats,
+ )
+ )
+ continue
+
+ # Check if the token is semantic
+ is_semantic = (
+ tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
+ )
+
+ if is_semantic:
+ # Before the semantic tokens, send the remaining text buffer
+ responses.extend(
+ send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+ )
+ responses.extend(
+ handle_semantic_tokens(tokens, config, sample_id, parts, request)
+ )
+ else:
+ # Accumulate the text tokens (not implemented?)
+ decode_buffer[sample_id].append(tokens[0, 0])
+
+ if is_first_token:
+ stats["time_to_first_token"] = (time.time() - start) * 1000
+
+ return responses
diff --git a/tools/server/agent/pre_generation_utils.py b/tools/server/agent/pre_generation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..135a72e396841f6ebfab340bb497d984238b78d5
--- /dev/null
+++ b/tools/server/agent/pre_generation_utils.py
@@ -0,0 +1,72 @@
+import queue
+
+from fish_speech.conversation import Conversation, Message
+from fish_speech.tokenizer import IM_END_TOKEN
+from tools.llama.generate import GenerateRequest
+
+
+def prepare_messages(request, tokenizer, config):
+ """
+ Reorganise the provided list of messages into a conversation.
+ Encode the conversation for inference.
+ """
+ # Convert the messages to ConversationMessage objects
+ messages = [msg.to_conversation_message() for msg in request.messages]
+
+ if len(messages) < 1:
+ raise ValueError("At least one message is required")
+
+ # Check the last message to determine the next step
+ last_role = messages[-1].role
+ match last_role:
+ case "user":
+ # The last message is from the user, ask the assistant to respond with a new message
+ messages.append(
+ Message(role="assistant", parts=[], add_im_end=False, modality="voice")
+ )
+ case "raw":
+ # The last message is raw text, ask the assistant to complete it
+ messages[-1].add_im_start = False
+ messages[-1].add_im_end = False
+ messages[-1].modality = "voice"
+ case "assistant":
+ # The last message is from the assistant, ask the assistant to continue
+ messages[-1].add_im_end = False
+ case _:
+ # We expect it to be assistant if not user or raw
+ raise ValueError("The last message must be from the assistant, user or raw")
+
+ # Create a conversation object and encode it for inference
+ conv = Conversation(messages=messages)
+ prompt = conv.encode_for_inference(
+ tokenizer=tokenizer, num_codebooks=config.num_codebooks
+ )
+ im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
+
+ return prompt, im_end_id
+
+
+def create_generation_request(prompt, request, im_end_id, device):
+ """
+ Convert the request into a dictionary that can be sent to the model for generation.
+ """
+ req = {
+ "prompt": prompt.to(device),
+ "max_new_tokens": request.max_new_tokens,
+ "im_end_id": im_end_id,
+ "temperature": request.temperature,
+ "top_p": request.top_p,
+ "repetition_penalty": request.repetition_penalty,
+ "num_samples": request.num_samples,
+ "early_stop_threshold": request.early_stop_threshold,
+ }
+ return req
+
+
+def send_generation_request(input_queue, req):
+ """
+ Send the generation request to the model and return a queue to get the response.
+ """
+ response_queue = queue.Queue()
+ input_queue.put(GenerateRequest(req, response_queue))
+ return response_queue
diff --git a/tools/server/api_utils.py b/tools/server/api_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cfe4c3a22eb9ecb6d4966ad03bd8205202437bf
--- /dev/null
+++ b/tools/server/api_utils.py
@@ -0,0 +1,75 @@
+from argparse import ArgumentParser
+from http import HTTPStatus
+from typing import Annotated, Any
+
+import ormsgpack
+from baize.datastructures import ContentType
+from kui.asgi import HTTPException, HttpRequest
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+ parser.add_argument("--load-asr-model", action="store_true")
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=str,
+ default="checkpoints/fish-speech-1.5",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=str,
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-text-length", type=int, default=0)
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
+ parser.add_argument("--workers", type=int, default=1)
+
+ return parser.parse_args()
+
+
+class MsgPackRequest(HttpRequest):
+ async def data(
+ self,
+ ) -> Annotated[
+ Any, ContentType("application/msgpack"), ContentType("application/json")
+ ]:
+ if self.content_type == "application/msgpack":
+ return ormsgpack.unpackb(await self.body)
+
+ elif self.content_type == "application/json":
+ return await self.json
+
+ raise HTTPException(
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+ headers={"Accept": "application/msgpack, application/json"},
+ )
+
+
+async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
+ for chunk in inference(req, engine):
+ if isinstance(chunk, bytes):
+ yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+ yield buffer
+
+
+def get_content_type(audio_format):
+ if audio_format == "wav":
+ return "audio/wav"
+ elif audio_format == "flac":
+ return "audio/flac"
+ elif audio_format == "mp3":
+ return "audio/mpeg"
+ else:
+ return "application/octet-stream"
diff --git a/tools/server/exception_handler.py b/tools/server/exception_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d595fabb7af4e00a1fb67a78b466fea0c2c0f4
--- /dev/null
+++ b/tools/server/exception_handler.py
@@ -0,0 +1,27 @@
+import traceback
+from http import HTTPStatus
+
+from kui.asgi import HTTPException, JSONResponse
+
+
+class ExceptionHandler:
+
+ async def http_exception_handler(self, exc: HTTPException):
+ return JSONResponse(
+ dict(
+ statusCode=exc.status_code,
+ message=exc.content,
+ error=HTTPStatus(exc.status_code).phrase,
+ ),
+ exc.status_code,
+ exc.headers,
+ )
+
+ async def other_exception_handler(self, exc: Exception):
+ traceback.print_exc()
+
+ status = HTTPStatus.INTERNAL_SERVER_ERROR
+ return JSONResponse(
+ dict(statusCode=status, message=str(exc), error=status.phrase),
+ status,
+ )
diff --git a/tools/server/inference.py b/tools/server/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cfdceeacd0ede5ffa7a369eab80294b42fdd22f
--- /dev/null
+++ b/tools/server/inference.py
@@ -0,0 +1,45 @@
+from http import HTTPStatus
+
+import numpy as np
+from kui.asgi import HTTPException
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.schema import ServeTTSRequest
+
+AMPLITUDE = 32768 # Needs an explaination
+
+
+def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
+ """
+ Wrapper for the inference function.
+ Used in the API server.
+ """
+ count = 0
+ for result in engine.inference(req):
+ match result.code:
+ case "header":
+ if isinstance(result.audio, tuple):
+ yield result.audio[1]
+
+ case "error":
+ raise HTTPException(
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ content=str(result.error),
+ )
+
+ case "segment":
+ count += 1
+ if isinstance(result.audio, tuple):
+ yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
+
+ case "final":
+ count += 1
+ if isinstance(result.audio, tuple):
+ yield result.audio[1]
+ return None # Stop the generator
+
+ if count == 0:
+ raise HTTPException(
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ content="No audio generated, please check the input text.",
+ )
diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2cceb6524fb12663328fe9fe5cf661484cbb66
--- /dev/null
+++ b/tools/server/model_manager.py
@@ -0,0 +1,122 @@
+import torch
+from funasr import AutoModel
+from loguru import logger
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import (
+ launch_thread_safe_queue,
+ launch_thread_safe_queue_agent,
+)
+from tools.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+from tools.vqgan.inference import load_model as load_decoder_model
+
+ASR_MODEL_NAME = "iic/SenseVoiceSmall"
+
+
+class ModelManager:
+ def __init__(
+ self,
+ mode: str,
+ device: str,
+ half: bool,
+ compile: bool,
+ asr_enabled: bool,
+ llama_checkpoint_path: str,
+ decoder_checkpoint_path: str,
+ decoder_config_name: str,
+ ) -> None:
+
+ self.mode = mode
+ self.device = device
+ self.half = half
+ self.compile = compile
+
+ self.precision = torch.half if half else torch.bfloat16
+
+ # Check if MPS or CUDA is available
+ if torch.backends.mps.is_available():
+ self.device = "mps"
+ logger.info("mps is available, running on mps.")
+ elif not torch.cuda.is_available():
+ self.device = "cpu"
+ logger.info("CUDA is not available, running on CPU.")
+
+ # Load the ASR model if enabled
+ if asr_enabled:
+ self.load_asr_model(self.device)
+
+ # Load the TTS models
+ self.load_llama_model(
+ llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
+ )
+ self.load_decoder_model(
+ decoder_config_name, decoder_checkpoint_path, self.device
+ )
+ self.tts_inference_engine = TTSInferenceEngine(
+ llama_queue=self.llama_queue,
+ decoder_model=self.decoder_model,
+ precision=self.precision,
+ compile=self.compile,
+ )
+
+ # Warm up the models
+ if self.mode == "tts":
+ self.warm_up(self.tts_inference_engine)
+
+ def load_asr_model(self, device, hub="ms") -> None:
+ self.asr_model = AutoModel(
+ model=ASR_MODEL_NAME,
+ device=device,
+ disable_pbar=True,
+ hub=hub,
+ )
+ logger.info("ASR model loaded.")
+
+ def load_llama_model(
+ self, checkpoint_path, device, precision, compile, mode
+ ) -> None:
+
+ if mode == "tts":
+ self.llama_queue = launch_thread_safe_queue(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=compile,
+ )
+ elif mode == "agent":
+ self.llama_queue, self.tokenizer, self.config = (
+ launch_thread_safe_queue_agent(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=compile,
+ )
+ )
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+
+ logger.info("LLAMA model loaded.")
+
+ def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
+ self.decoder_model = load_decoder_model(
+ config_name=config_name,
+ checkpoint_path=checkpoint_path,
+ device=device,
+ )
+ logger.info("Decoder model loaded.")
+
+ def warm_up(self, tts_inference_engine) -> None:
+ request = ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ format="wav",
+ )
+ list(inference(request, tts_inference_engine))
+ logger.info("Models warmed up.")
diff --git a/tools/server/model_utils.py b/tools/server/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a4c3a0c176bc4466b12a0f24f57f3f8299a265
--- /dev/null
+++ b/tools/server/model_utils.py
@@ -0,0 +1,129 @@
+import io
+import re
+
+import librosa
+import torch
+import torchaudio
+from cachetools import LRUCache, cached
+
+CACHE_MAXSIZE = 10000
+MICRO_BATCH_SIZE = 8
+ASR_SAMPLE_RATE = 16000
+HUGE_GAP_THRESHOLD = 4000
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios_list: list[bytes]):
+ audios: list[torch.Tensor] = [
+ (
+ torch.from_numpy(
+ librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+ )[None]
+ if isinstance(audio, bytes)
+ else audio
+ )
+ for audio in audios_list
+ ]
+
+ lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+ max_length = lengths.max().item()
+
+ print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+ padded = torch.stack(
+ [
+ torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
+ for audio in audios
+ ]
+ ).to(model.device)
+
+ features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+ features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+ return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+ cache=LRUCache(maxsize=CACHE_MAXSIZE),
+ key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+ return batch_encode(model, audios)
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def vqgan_decode(model, features):
+ lengths = torch.tensor(
+ [feature.shape[-1] for feature in features], device=model.device
+ )
+ max_length = lengths.max().item()
+ padded = torch.stack(
+ [
+ torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+ for feature in features
+ ]
+ ).to(model.device)
+
+ # If bs too large, we do micro batch decode
+ audios, audio_lengths = [], []
+ for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
+ audio, audio_length = model.decode(
+ padded[i : i + MICRO_BATCH_SIZE],
+ feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
+ )
+ audios.append(audio)
+ audio_lengths.append(audio_length)
+ audios = torch.cat(audios, dim=0)
+ audio_lengths = torch.cat(audio_lengths, dim=0)
+ audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+ return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@torch.no_grad()
+def batch_asr(model, lock, audios, sr, language="auto"):
+ resampled_audios = []
+ for audio in audios:
+ audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
+ assert audio.ndim == 1
+ resampled_audios.append(audio)
+
+ with lock:
+ res = model.generate(
+ input=resampled_audios,
+ batch_size=len(resampled_audios),
+ language=language,
+ use_itn=True,
+ )
+
+ results = []
+ for r, audio in zip(res, audios):
+ text = r["text"]
+ text = re.sub(r"<\|.*?\|>", "", text)
+ duration = len(audio) / sr * 1000
+ huge_gap = False
+
+ if "timestamp" in r and len(r["timestamp"]) > 2:
+ for timestamp_a, timestamp_b in zip(
+ r["timestamp"][:-1], r["timestamp"][1:]
+ ):
+ # If there is a gap of more than 4 seconds, we consider it as a huge gap
+ if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
+ huge_gap = True
+ break
+
+ # Doesn't make sense to have a huge gap at the end
+ if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
+ huge_gap = True
+
+ results.append(
+ {
+ "text": text,
+ "duration": duration,
+ "huge_gap": huge_gap,
+ }
+ )
+
+ return results
diff --git a/tools/server/views.py b/tools/server/views.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f54fa0c10924aa7409338a94289def9392375ce
--- /dev/null
+++ b/tools/server/views.py
@@ -0,0 +1,246 @@
+import io
+import os
+import time
+from http import HTTPStatus
+
+import numpy as np
+import ormsgpack
+import soundfile as sf
+import torch
+from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
+from loguru import logger
+
+from tools.schema import (
+ ServeASRRequest,
+ ServeASRResponse,
+ ServeChatRequest,
+ ServeTTSRequest,
+ ServeVQGANDecodeRequest,
+ ServeVQGANDecodeResponse,
+ ServeVQGANEncodeRequest,
+ ServeVQGANEncodeResponse,
+)
+from tools.server.agent import get_response_generator
+from tools.server.api_utils import (
+ buffer_to_async_generator,
+ get_content_type,
+ inference_async,
+)
+from tools.server.inference import inference_wrapper as inference
+from tools.server.model_manager import ModelManager
+from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
+
+MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
+
+
+class HealthView(HttpView):
+ """
+ Return the health status of the server.
+ """
+
+ @classmethod
+ async def post(cls):
+ return JSONResponse({"status": "ok"})
+
+
+class VQGANEncodeView(HttpView):
+ """
+ Encode the audio into symbolic tokens.
+ """
+
+ @classmethod
+ async def post(cls):
+ # Decode the request
+ payload = await request.data()
+ req = ServeVQGANEncodeRequest(**payload)
+
+ # Get the model from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ decoder_model = model_manager.decoder_model
+
+ # Encode the audio
+ start_time = time.time()
+ tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
+ logger.info(
+ f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
+ )
+
+ # Return the response
+ return ormsgpack.packb(
+ ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+class VQGANDecodeView(HttpView):
+ """
+ Decode the symbolic tokens into audio.
+ """
+
+ @classmethod
+ async def post(cls):
+ # Decode the request
+ payload = await request.data()
+ req = ServeVQGANDecodeRequest(**payload)
+
+ # Get the model from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ decoder_model = model_manager.decoder_model
+
+ # Decode the audio
+ tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
+ start_time = time.time()
+ audios = vqgan_decode(decoder_model, tokens)
+ logger.info(
+ f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
+ )
+ audios = [audio.astype(np.float16).tobytes() for audio in audios]
+
+ # Return the response
+ return ormsgpack.packb(
+ ServeVQGANDecodeResponse(audios=audios),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+class ASRView(HttpView):
+ """
+ Perform automatic speech recognition on the audio.
+ """
+
+ @classmethod
+ async def post(cls):
+ # Decode the request
+ payload = await request.data()
+ req = ServeASRRequest(**payload)
+
+ # Get the model from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ asr_model = model_manager.asr_model
+ lock = request.app.state.lock
+
+ # Perform ASR
+ start_time = time.time()
+ audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
+ audios = [torch.from_numpy(audio).float() for audio in audios]
+
+ if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
+ raise HTTPException(status_code=400, content="Audio length is too long")
+
+ transcriptions = batch_asr(
+ asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
+ )
+ logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+ # Return the response
+ return ormsgpack.packb(
+ ServeASRResponse(transcriptions=transcriptions),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+class TTSView(HttpView):
+ """
+ Perform text-to-speech on the input text.
+ """
+
+ @classmethod
+ async def post(cls):
+ # Decode the request
+ payload = await request.data()
+ req = ServeTTSRequest(**payload)
+
+ # Get the model from the app
+ app_state = request.app.state
+ model_manager: ModelManager = app_state.model_manager
+ engine = model_manager.tts_inference_engine
+ sample_rate = engine.decoder_model.spec_transform.sample_rate
+
+ # Check if the text is too long
+ if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content=f"Text is too long, max length is {app_state.max_text_length}",
+ )
+
+ # Check if streaming is enabled
+ if req.streaming and req.format != "wav":
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content="Streaming only supports WAV format",
+ )
+
+ # Perform TTS
+ if req.streaming:
+ return StreamResponse(
+ iterable=inference_async(req, engine),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+ else:
+ fake_audios = next(inference(req, engine))
+ buffer = io.BytesIO()
+ sf.write(
+ buffer,
+ fake_audios,
+ sample_rate,
+ format=req.format,
+ )
+
+ return StreamResponse(
+ iterable=buffer_to_async_generator(buffer.getvalue()),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+
+
+class ChatView(HttpView):
+ """
+ Perform chatbot inference on the input text.
+ """
+
+ @classmethod
+ async def post(cls):
+ # Decode the request
+ payload = await request.data()
+ req = ServeChatRequest(**payload)
+
+ # Check that the number of samples requested is correct
+ if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
+ )
+
+ # Get the type of content provided
+ content_type = request.headers.get("Content-Type", "application/json")
+ json_mode = "application/json" in content_type
+
+ # Get the models from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ llama_queue = model_manager.llama_queue
+ tokenizer = model_manager.tokenizer
+ config = model_manager.config
+
+ device = request.app.state.device
+
+ # Get the response generators
+ response_generator = get_response_generator(
+ llama_queue, tokenizer, config, req, device, json_mode
+ )
+
+ # Return the response in the correct format
+ if req.streaming is False:
+ result = response_generator()
+ if json_mode:
+ return JSONResponse(result.model_dump())
+ else:
+ return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+ return StreamResponse(
+ iterable=response_generator(), content_type="text/event-stream"
+ )
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9dc154f26b2869a7e34f7d4cd95db741ee4c6a
--- /dev/null
+++ b/tools/smart_pad.py
@@ -0,0 +1,60 @@
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+threshold = 10 ** (-50 / 20.0)
+
+
+def process(file):
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
+ if waveform.size(0) > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+
+ loudness = librosa.feature.rms(
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
+ )[0]
+
+ for i in range(len(loudness) - 1, 0, -1):
+ if loudness[i] > threshold:
+ break
+
+ end_silent_time = (len(loudness) - i) * 512 / sample_rate
+
+ if end_silent_time <= 0.3:
+ random_time = random.uniform(0.3, 0.7) - end_silent_time
+ waveform = F.pad(
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
+ )
+
+ for i in range(len(loudness)):
+ if loudness[i] > threshold:
+ break
+
+ start_silent_time = i * 512 / sample_rate
+
+ if start_silent_time > 0.02:
+ waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
+
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
+@click.command()
+@click.argument("source", type=Path)
+@click.option("--num-workers", type=int, default=12)
+def main(source, num_workers):
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
+
+ with Pool(num_workers) as p:
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/__pycache__/inference.cpython-310.pyc b/tools/vqgan/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c1dcf930fc6ffd847ec654a1ba699c499ec1963
Binary files /dev/null and b/tools/vqgan/__pycache__/inference.cpython-310.pyc differ
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..d24a5f39566c47ea0cb1fc506d463e9c95c3efbc
--- /dev/null
+++ b/tools/vqgan/create_train_split.py
@@ -0,0 +1,83 @@
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+@click.option("--val-ratio", type=float, default=None)
+@click.option("--val-count", type=int, default=None)
+@click.option("--filelist", default=None, type=Path)
+@click.option("--min-duration", default=None, type=float)
+@click.option("--max-duration", default=None, type=float)
+def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+ if min_duration is None and max_duration is None:
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
+ else:
+ filtered_files = []
+ for file in tqdm(files):
+ try:
+ audio = AudioSegment.from_file(str(file))
+ duration = len(audio) / 1000.0
+
+ if min_duration is not None and duration < min_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
+ )
+ continue
+
+ if max_duration is not None and duration > max_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
+ )
+ continue
+
+ filtered_files.append(str(file.relative_to(root)))
+ except Exception as e:
+ logger.info(f"Error processing {file}: {e}")
+
+ logger.info(
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
+ )
+
+ Random(42).shuffle(filtered_files)
+
+ if val_count is None and val_ratio is None:
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
+ elif val_count is not None and val_ratio is not None:
+ logger.error("Cannot specify both val_count and val_ratio")
+ return
+ elif val_count is not None:
+ if val_count < 1 or val_count > len(filtered_files):
+ logger.error("val_count must be between 1 and number of files")
+ return
+ val_size = val_count
+ else:
+ val_size = math.ceil(len(filtered_files) * val_ratio)
+
+ logger.info(f"Using {val_size} files for validation")
+
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[val_size:]))
+
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[:val_size]))
+
+ logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b2be2e671be1746ecf0ba95676aeaff83f5957f
--- /dev/null
+++ b/tools/vqgan/extract_vq.py
@@ -0,0 +1,233 @@
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file is used to convert the audio files to text files using the Whisper model.
+# It's mainly used to generate the training data for the VQ model.
+
+backends = torchaudio.list_audio_backends()
+
+if "ffmpeg" in backends:
+ backend = "ffmpeg"
+else:
+ backend = "soundfile"
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+ "{level: <8} | "
+ "{name} :{function} :{line} | "
+ "{extra[rank]} - {message} "
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model(
+ config_name: str = "firefly_gan_vq",
+ checkpoint_path: str = "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ device: str | torch.device = "cuda",
+):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model")
+ return model
+
+
+@torch.inference_mode()
+def process_batch(files: list[Path], model) -> float:
+ wavs = []
+ audio_lengths = []
+ new_files = []
+ max_length = total_time = 0
+
+ for file in files:
+ try:
+ wav, sr = torchaudio.load(
+ str(file), backend=backend
+ ) # Need to install libsox-dev
+ except Exception as e:
+ logger.error(f"Error reading {file}: {e}")
+ continue
+
+ if wav.shape[0] > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ wav = torchaudio.functional.resample(
+ wav.cuda(), sr, model.spec_transform.sample_rate
+ )[0]
+ total_time += len(wav) / model.spec_transform.sample_rate
+ max_length = max(max_length, len(wav))
+
+ wavs.append(wav)
+ audio_lengths.append(len(wav))
+ new_files.append(file)
+
+ files = new_files
+
+ # Pad to max length
+ for i, wav in enumerate(wavs):
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+ audios = torch.stack(wavs, dim=0)[:, None]
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+ # Calculate lengths
+ indices, feature_lengths = model.encode(audios, audio_lengths)
+
+ # Save to disk
+ outputs = indices.cpu().numpy()
+
+ for file, length, feature, audio_length in zip(
+ files, feature_lengths, outputs, audio_lengths
+ ):
+ feature = feature[:, :length]
+
+ # (T,)
+ with open(file.with_suffix(".npy"), "wb") as f:
+ np.save(f, feature)
+
+ return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
+def main(
+ folder: str,
+ num_workers: int,
+ config_name: str,
+ checkpoint_path: str,
+ batch_size: int,
+ filelist: Path,
+):
+ if num_workers > 1 and WORLD_SIZE != num_workers:
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+ logger.info(f"Spawning {num_workers} workers")
+
+ if torch.cuda.is_available():
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is None:
+ visible_devices = list(range(torch.cuda.device_count()))
+ else:
+ visible_devices = visible_devices.split(",")
+ else:
+ # Set to empty string to avoid using GPU
+ visible_devices = [""]
+
+ processes = []
+ for i in range(num_workers):
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+ env["SLURM_PROCID"] = str(i)
+ env["SLURM_NTASKS"] = str(num_workers)
+
+ processes.append(
+ sp.Popen(
+ [sys.executable] + sys.argv.copy(),
+ env=env,
+ )
+ )
+
+ for p in processes:
+ p.wait()
+
+ logger.info(f"All workers finished")
+ return
+
+ # This is a worker
+ logger.info(f"Starting worker")
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+ print(f"Found {len(files)} files")
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+ total_files = len(files)
+ files = files[RANK::WORLD_SIZE]
+ logger.info(f"Processing {len(files)}/{total_files} files")
+
+ # Batch processing
+ total_time = 0
+ begin_time = time.time()
+ processed_files = 0
+ model = get_model(config_name, checkpoint_path)
+
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+ batch = files[idx : idx + batch_size]
+ batch_time = process_batch(batch, model)
+
+ total_time += batch_time
+ processed_files += len(batch)
+
+ if (n_batch + 1) % 10 == 0:
+ eta = (
+ (time.time() - begin_time)
+ / processed_files
+ * (len(files) - processed_files)
+ )
+ logger.info(
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+ + f"ETA: {timedelta(seconds=round(eta))}s"
+ )
+
+ logger.info(
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a63165045fb28910307a34b1930e5770c5316bd
--- /dev/null
+++ b/tools/vqgan/inference.py
@@ -0,0 +1,121 @@
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path, map_location=device, mmap=True, weights_only=True
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ result = model.load_state_dict(state_dict, strict=False, assign=True)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model: {result}")
+ return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+ "--input-path",
+ "-i",
+ default="test.wav",
+ type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option(
+ "--device",
+ "-d",
+ default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+ model = load_model(config_name, checkpoint_path, device=device)
+
+ if input_path.suffix in AUDIO_EXTENSIONS:
+ logger.info(f"Processing in-place reconstruction of {input_path}")
+
+ # Load audio
+ audio, sr = torchaudio.load(str(input_path))
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(
+ audio, sr, model.spec_transform.sample_rate
+ )
+
+ audios = audio[None].to(device)
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
+ indices = model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Generated indices of shape {indices.shape}")
+
+ # Save indices
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+ elif input_path.suffix == ".npy":
+ logger.info(f"Processing precomputed indices from {input_path}")
+ indices = np.load(input_path)
+ indices = torch.from_numpy(indices).to(device).long()
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+ else:
+ raise ValueError(f"Unknown input type: {input_path}")
+
+ # Restore
+ feature_lengths = torch.tensor([indices.shape[1]], device=device)
+ fake_audios, _ = model.decode(
+ indices=indices[None], feature_lengths=feature_lengths
+ )
+ audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
+
+ logger.info(
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+ )
+
+ # Save audio
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
+ sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/webui/__init__.py b/tools/webui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..78dbc343f418149735c16c7b5d33b247c7b5411b
--- /dev/null
+++ b/tools/webui/__init__.py
@@ -0,0 +1,173 @@
+from typing import Callable
+
+import gradio as gr
+
+from fish_speech.i18n import i18n
+from tools.inference_engine.utils import normalize_text
+from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
+
+
+def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
+ % theme,
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+ )
+ refined_text = gr.Textbox(
+ label=i18n("Realtime Transform Text"),
+ placeholder=i18n(
+ "Normalization Result Preview (Currently Only Chinese)"
+ ),
+ lines=5,
+ interactive=False,
+ )
+
+ with gr.Row():
+ normalize = gr.Checkbox(
+ label=i18n("Text Normalization"),
+ value=False,
+ )
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab(label=i18n("Advanced Config")):
+ with gr.Row():
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=0,
+ maximum=300,
+ value=200,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n(
+ "Maximum tokens per batch, 0 means no limit"
+ ),
+ minimum=0,
+ maximum=2048,
+ value=0,
+ step=8,
+ )
+
+ with gr.Row():
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.5,
+ value=1.2,
+ step=0.01,
+ )
+
+ with gr.Row():
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+ seed = gr.Number(
+ label="Seed",
+ info="0 means randomized inference, otherwise deterministic",
+ value=0,
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ with gr.Row():
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+ with gr.Row():
+ reference_id = gr.Textbox(
+ label=i18n("Reference ID"),
+ placeholder="Leave empty to use uploaded references",
+ )
+
+ with gr.Row():
+ use_memory_cache = gr.Radio(
+ label=i18n("Use Memory Cache"),
+ choices=["on", "off"],
+ value="on",
+ )
+
+ with gr.Row():
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ with gr.Row():
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ lines=1,
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ value="",
+ )
+
+ with gr.Column(scale=3):
+ with gr.Row():
+ error = gr.HTML(
+ label=i18n("Error Message"),
+ visible=True,
+ )
+ with gr.Row():
+ audio = gr.Audio(
+ label=i18n("Generated Audio"),
+ type="numpy",
+ interactive=False,
+ visible=True,
+ )
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001F3A7 " + i18n("Generate"),
+ variant="primary",
+ )
+
+ text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
+
+ # Submit
+ generate.click(
+ inference_fct,
+ [
+ refined_text,
+ normalize,
+ reference_id,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ use_memory_cache,
+ ],
+ [audio, error],
+ concurrency_limit=1,
+ )
+
+ return app
diff --git a/tools/webui/__pycache__/__init__.cpython-310.pyc b/tools/webui/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..999ec585a0d746f2f733401052dc62deeb7154d9
Binary files /dev/null and b/tools/webui/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tools/webui/__pycache__/__init__.cpython-38.pyc b/tools/webui/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a83f79fd8f543b6091ae67490c04134fa2db0bcc
Binary files /dev/null and b/tools/webui/__pycache__/__init__.cpython-38.pyc differ
diff --git a/tools/webui/__pycache__/inference.cpython-310.pyc b/tools/webui/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dabcc814a25e8f309f261e0cccf698ba1a6acd5b
Binary files /dev/null and b/tools/webui/__pycache__/inference.cpython-310.pyc differ
diff --git a/tools/webui/__pycache__/variables.cpython-310.pyc b/tools/webui/__pycache__/variables.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..223079322cd5c8e8a3fb7bff90ff5b98b51f129c
Binary files /dev/null and b/tools/webui/__pycache__/variables.cpython-310.pyc differ
diff --git a/tools/webui/inference.py b/tools/webui/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea3553be7784c0cda1b6e75b87d8d06cc9c92699
--- /dev/null
+++ b/tools/webui/inference.py
@@ -0,0 +1,91 @@
+import html
+from functools import partial
+from typing import Any, Callable
+
+from fish_speech.i18n import i18n
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
+
+
+def inference_wrapper(
+ text,
+ normalize,
+ reference_id,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ use_memory_cache,
+ engine,
+):
+ """
+ Wrapper for the inference function.
+ Used in the Gradio interface.
+ """
+
+ if reference_audio:
+ references = get_reference_audio(reference_audio, reference_text)
+ else:
+ references = []
+
+ req = ServeTTSRequest(
+ text=text,
+ normalize=normalize,
+ reference_id=reference_id if reference_id else None,
+ references=references,
+ max_new_tokens=max_new_tokens,
+ chunk_length=chunk_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ seed=int(seed) if seed else None,
+ use_memory_cache=use_memory_cache,
+ )
+
+ for result in engine.inference(req):
+ match result.code:
+ case "final":
+ return result.audio, None
+ case "error":
+ return None, build_html_error_message(i18n(result.error))
+ case _:
+ pass
+
+ return None, i18n("No audio generated")
+
+
+def get_reference_audio(reference_audio: str, reference_text: str) -> list:
+ """
+ Get the reference audio bytes.
+ """
+
+ with open(reference_audio, "rb") as audio_file:
+ audio_bytes = audio_file.read()
+
+ return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
+
+
+def build_html_error_message(error: Any) -> str:
+
+ error = error if isinstance(error, Exception) else Exception("Unknown error")
+
+ return f"""
+
+ {html.escape(str(error))}
+
+ """
+
+
+def get_inference_wrapper(engine) -> Callable:
+ """
+ Get the inference function with the immutable arguments.
+ """
+
+ return partial(
+ inference_wrapper,
+ engine=engine,
+ )
diff --git a/tools/webui/variables.py b/tools/webui/variables.py
new file mode 100644
index 0000000000000000000000000000000000000000..db42d5d797e821e9a34832bea4344ecb726c97db
--- /dev/null
+++ b/tools/webui/variables.py
@@ -0,0 +1,14 @@
+from fish_speech.i18n import i18n
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
+
+{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..42e7de8a185880d3f2afd368d6df3429488465a4
--- /dev/null
+++ b/tools/whisper_asr.py
@@ -0,0 +1,176 @@
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+import re
+from pathlib import Path
+
+import click
+import soundfile as sf
+from faster_whisper import WhisperModel
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
+@click.option(
+ "--compute-type",
+ default="float16",
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
+)
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option(
+ "--sample-rate",
+ default=44100,
+ type=int,
+ help="Output sample rate, default to input sample rate",
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+ model_size,
+ compute_type,
+ audio_dir,
+ save_dir,
+ sample_rate,
+ device,
+ language,
+ initial_prompt,
+):
+ logger.info("Loading / Downloading Faster Whisper model...")
+
+ model = WhisperModel(
+ model_size,
+ device=device,
+ compute_type=compute_type,
+ download_root="faster_whisper",
+ )
+
+ logger.info("Model loaded.")
+
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
+ segments, info = model.transcribe(
+ file_path,
+ beam_size=5,
+ language=None if language == "auto" else language,
+ initial_prompt=initial_prompt,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ whole_text = None
+ for segment in segments:
+ id, start, end, text = (
+ segment.id,
+ segment.start,
+ segment.end,
+ segment.text,
+ )
+ print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
+ if not whole_text:
+ whole_text = text
+ else:
+ whole_text += ", " + text
+
+ whole_text += "."
+
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
+ audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}")
+
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(whole_text)
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+
+ audio = AudioSegment.from_wav(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
+ )
+
+ model_size = "large-v3"
+
+ model = WhisperModel(
+ model_size,
+ device="cuda",
+ compute_type="float16",
+ download_root="faster_whisper",
+ )
+
+ segments, info = model.transcribe(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
+ beam_size=5,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ for i, segment in enumerate(segments):
+ print(
+ "Segment %03d [%.2fs -> %.2fs] %s"
+ % (i, segment.start, segment.end, segment.text)
+ )
+ start_ms = int(segment.start * 1000)
+ end_ms = int(segment.end * 1000)
+ segment_audio = audio[start_ms:end_ms]
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
+ print(f"Exported segment_{i:03d}.wav")
+
+ print("All segments have been exported.")