Spaces:
Running
Running
from argparse import ArgumentParser | |
from http import HTTPStatus | |
from typing import Annotated, Any | |
import ormsgpack | |
from baize.datastructures import ContentType | |
from kui.asgi import HTTPException, HttpRequest | |
from tools.inference_engine import TTSInferenceEngine | |
from tools.schema import ServeTTSRequest | |
from tools.server.inference import inference_wrapper as inference | |
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is exactly what the
            previous zero-argument version did; pass an explicit list to
            parse options programmatically (tests, embedding).

    Returns:
        argparse.Namespace with the attributes ``mode``, ``load_asr_model``,
        ``llama_checkpoint_path``, ``decoder_checkpoint_path``,
        ``decoder_config_name``, ``device``, ``half``, ``compile``,
        ``max_text_length``, ``listen`` and ``workers``.

    Raises:
        SystemExit: raised by argparse on invalid or unknown options
            (e.g. a ``--mode`` outside ``{"agent", "tts"}``).
    """
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
    # argparse treats None as "use sys.argv[1:]", preserving the old behavior.
    return parser.parse_args(argv)
class MsgPackRequest(HttpRequest):
    """HTTP request whose ``data()`` decodes either MessagePack or JSON bodies."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request body according to its Content-Type.

        Returns:
            The object produced by ``ormsgpack.unpackb`` for an
            ``application/msgpack`` body, or the result of awaiting the
            request's ``json`` accessor for an ``application/json`` body.

        Raises:
            HTTPException: 415 Unsupported Media Type for any other
                Content-Type, carrying an ``Accept`` header that lists the
                two supported media types.
        """
        kind = self.content_type
        if kind == "application/msgpack":
            payload = await self.body
            return ormsgpack.unpackb(payload)
        if kind == "application/json":
            return await self.json
        supported = "application/msgpack, application/json"
        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": supported},
        )
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Re-expose the synchronous ``inference(req, engine)`` generator as an async one.

    Only items that are ``bytes`` are forwarded to the consumer; any other
    item the underlying generator produces is silently skipped.

    NOTE(review): the underlying generator is iterated directly on the
    event-loop thread with no ``await`` in between, so producing each chunk
    blocks the loop until ``inference`` yields. Consider offloading the
    iteration (e.g. ``asyncio.to_thread``) once the engine's thread-safety
    has been confirmed.
    """
    for piece in inference(req, engine):
        if not isinstance(piece, bytes):
            continue  # non-bytes items are dropped, not forwarded
        yield piece
async def buffer_to_async_generator(buffer):
    """Return an async generator that yields ``buffer`` exactly once.

    Lets an already-materialized payload be passed where an async iterable
    of chunks is expected.
    """
    for whole in (buffer,):
        yield whole
def get_content_type(audio_format):
    """Map an audio format name to the MIME type to send with the audio.

    Recognized names are ``"wav"``, ``"flac"`` and ``"mp3"`` (exact,
    case-sensitive match); anything else falls back to the generic
    ``"application/octet-stream"``.
    """
    # Scanned in order with ``==`` (rather than a dict lookup) so that
    # unhashable inputs fall through to the default instead of raising.
    known = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    for name, mime in known:
        if audio_format == name:
            return mime
    return "application/octet-stream"