from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any

import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import HTTPException, HttpRequest

from tools.inference_engine import TTSInferenceEngine
from tools.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference


def parse_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    # Options are registered in this exact order so the generated help
    # text lists them the same way as the flags are declared here.
    option_specs = (
        ("--mode", dict(type=str, choices=["agent", "tts"], default="tts")),
        ("--load-asr-model", dict(action="store_true")),
        (
            "--llama-checkpoint-path",
            dict(type=str, default="checkpoints/fish-speech-1.5"),
        ),
        (
            "--decoder-checkpoint-path",
            dict(
                type=str,
                default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
            ),
        ),
        ("--decoder-config-name", dict(type=str, default="firefly_gan_vq")),
        ("--device", dict(type=str, default="cuda")),
        ("--half", dict(action="store_true")),
        ("--compile", dict(action="store_true")),
        ("--max-text-length", dict(type=int, default=0)),
        ("--listen", dict(type=str, default="127.0.0.1:8080")),
        ("--workers", dict(type=int, default=1)),
    )

    cli = ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


class MsgPackRequest(HttpRequest):
    """Request type whose ``data`` accepts msgpack or JSON bodies."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request body according to its content type.

        Returns:
            The result of ``ormsgpack.unpackb`` on the raw body for
            ``application/msgpack``, or the awaited ``self.json`` for
            ``application/json``.

        Raises:
            HTTPException: with ``HTTPStatus.UNSUPPORTED_MEDIA_TYPE`` and an
                ``Accept`` header when the content type is neither of the two.
        """
        if self.content_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if self.content_type == "application/json":
            return await self.json

        # Neither supported media type matched: report what is accepted.
        accepted = {"Accept": "application/msgpack, application/json"}
        raise HTTPException(HTTPStatus.UNSUPPORTED_MEDIA_TYPE, headers=accepted)


async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Yield the ``bytes`` items produced by ``inference(req, engine)``.

    Items that are not ``bytes`` instances are skipped.

    NOTE(review): ``inference`` is consumed with a plain ``for`` loop, so each
    step runs synchronously inside this coroutine - confirm that is intended.
    """
    for item in inference(req, engine):
        if not isinstance(item, bytes):
            continue
        yield item


async def buffer_to_async_generator(buffer):
    """Expose a single value as an async generator that yields it once."""
    yield buffer


def get_content_type(audio_format):
    """Return the MIME type string for ``audio_format``.

    ``"wav"``, ``"flac"`` and ``"mp3"`` map to their audio MIME types; any
    other value falls back to ``"application/octet-stream"``.
    """
    known_formats = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    # Equality scan (rather than a dict lookup) keeps unhashable inputs
    # falling through to the default instead of raising.
    for name, mime_type in known_formats:
        if audio_format == name:
            return mime_type
    return "application/octet-stream"