# fish-audio-t/tools/server/api_utils.py
# kiylu: add project files
# b128c76
from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any
import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import HTTPException, HttpRequest
from tools.inference_engine import TTSInferenceEngine
from tools.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
def parse_args(argv=None):
    """Parse the command-line options for the API server.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before. Passing an
            explicit list lets tests and embedding code build a
            configuration without touching the process command line.

    Returns:
        The ``argparse.Namespace`` produced by the parser, with the
        attributes ``mode``, ``load_asr_model``, ``llama_checkpoint_path``,
        ``decoder_checkpoint_path``, ``decoder_config_name``, ``device``,
        ``half``, ``compile``, ``max_text_length``, ``listen`` and
        ``workers``.
    """
    parser = ArgumentParser()

    # Server mode and optional components.
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")

    # Model checkpoints and decoder configuration.
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")

    # Runtime switches.
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)

    # Serving options.
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
class MsgPackRequest(HttpRequest):
    """HTTP request whose ``data`` accepts MessagePack or JSON bodies."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request body according to its content type.

        Returns:
            For ``application/msgpack`` the result of ``ormsgpack.unpackb``
            applied to the awaited body; for ``application/json`` the
            awaited ``self.json``.

        Raises:
            HTTPException: With status 415 (Unsupported Media Type) and an
                ``Accept`` header naming both supported types when the
                content type is neither of the above.
        """
        media_type = self.content_type

        if media_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if media_type == "application/json":
            return await self.json

        # Anything else is rejected; the header tells the client what is accepted.
        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack, application/json"},
        )
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Expose the byte chunks of ``inference(req, engine)`` as an async generator.

    Items produced by ``inference`` that are not ``bytes`` instances are
    skipped; every ``bytes`` item is yielded in order.

    NOTE(review): the underlying iteration is a plain synchronous loop with
    no ``await`` between items — confirm that this is acceptable for the
    event loop it runs on.
    """
    byte_chunks = (
        item for item in inference(req, engine) if isinstance(item, bytes)
    )
    for byte_chunk in byte_chunks:
        yield byte_chunk
async def buffer_to_async_generator(buffer):
    """Wrap ``buffer`` in an async generator that yields it exactly once."""
    for payload in (buffer,):
        yield payload
def get_content_type(audio_format):
    """Return the MIME type string for an audio format name.

    ``"wav"``, ``"flac"`` and ``"mp3"`` map to ``"audio/wav"``,
    ``"audio/flac"`` and ``"audio/mpeg"`` respectively; any other value
    yields ``"application/octet-stream"``.
    """
    known_types = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    # Compare with == (not a dict lookup) so unhashable inputs still fall
    # through to the generic type instead of raising.
    for format_name, mime_type in known_types:
        if audio_format == format_name:
            return mime_type
    return "application/octet-stream"