# Spaces: Running on Zero
import base64
import importlib.util
import io
import os
import tempfile
import wave
from typing import List

import torch
import numpy as np
from pydantic import BaseModel
import spaces
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from trainer.io import get_user_data_dir
from TTS.utils.manage import ModelManager
# Pre-accept the Coqui terms of service (read by the TTS package; presumably
# so the model download below does not stop at an interactive prompt).
os.environ["COQUI_TOS_AGREED"] = "1"

# Torch CPU threads: NUM_THREADS if set, else one per CPU. os.cpu_count() may
# return None, which int() rejects, so fall back to 1 in that case.
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count() or 1)))

# USE_CPU unset or "0" selects CUDA; any other value forces the CPU.
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
# Compare device.type, not the device object: a torch.device never compares
# equal to a plain string, so the old `device == "cuda"` test was always False
# and neither this guard nor the DeepSpeed switch below ever took effect.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

# Use a custom checkpoint directory when it contains a config.json; otherwise
# download the default multilingual XTTS v2 model.
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.isfile(os.path.join(custom_model_path, "config.json")):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    # The downloaded model is expected under the user data dir, with "/" in
    # the model name replaced by "--".
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
# DeepSpeed is only requested on CUDA and is an optional dependency: enable it
# only when it is importable, and allow opting out with USE_DEEPSPEED=0.
use_deepspeed = (
    device.type == "cuda"
    and os.environ.get("USE_DEEPSPEED", "1") == "1"
    and importlib.util.find_spec("deepspeed") is not None
)
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=use_deepspeed)
model.to(device)
print("XTTS Loaded.", flush=True)
print("Running XTTS Server ...", flush=True)
def _conditioning_latents(audio_path):
    """Run ``model.get_conditioning_latents`` on a file path under inference mode."""
    with torch.inference_mode():
        return model.get_conditioning_latents(audio_path)


# @app.post("/clone_speaker")
def predict_speaker(wav_file):
    """Compute conditioning inputs from reference audio file.

    Args:
        wav_file: Either a filesystem path (``str``) to the reference audio, or
            a readable binary file-like object holding the audio bytes.

    Returns:
        A dict with ``"gpt_cond_latent"`` and ``"speaker_embedding"``: the two
        tensors returned by ``model.get_conditioning_latents``, each moved to
        the CPU, squeezed, cast to half precision and converted to nested lists.
    """
    if isinstance(wav_file, str):
        # A path can be handed to the model directly; no copy (and no file
        # handle to leak) is needed.
        gpt_cond_latent, speaker_embedding = _conditioning_latents(wav_file)
    else:
        # get_conditioning_latents takes a path, so spill the stream to a
        # private temp file. The file is closed (and therefore flushed) before
        # the model reads it, and it is always removed afterwards.
        fd, temp_audio_name = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "wb") as temp:
                temp.write(wav_file.read())
            gpt_cond_latent, speaker_embedding = _conditioning_latents(temp_audio_name)
        finally:
            os.remove(temp_audio_name)
    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }
def postprocess(wav):
    """Turn a model waveform into a ``(1, n_samples)`` int16 PCM array.

    ``wav`` is a tensor or a list of tensors; a list is concatenated along
    dim 0 first. Samples are clipped to [-1, 1], scaled by 32767 and
    truncated to ``np.int16``.
    """
    waveform = torch.cat(wav, dim=0) if isinstance(wav, list) else wav
    samples = waveform.clone().detach().cpu().numpy()
    # Prepend a singleton axis; the slice itself keeps every sample.
    n_samples = int(samples.shape[0])
    samples = samples[np.newaxis, :n_samples]
    scaled = np.clip(samples, -1, 1) * 32767
    return scaled.astype(np.int16)
def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Wrap raw PCM frames in an in-memory WAV container.

    Returns the WAV bytes base64-encoded as a UTF-8 ``str`` when
    ``encode_base64`` is true, otherwise the raw WAV ``bytes``.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    buffer.seek(0)
    if not encode_base64:
        return buffer.read()
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
# Request body for the streaming TTS endpoint, which is currently commented
# out below (predict_streaming_endpoint).
class StreamingInputs(BaseModel):
    # Flat speaker embedding; presumably the "speaker_embedding" list
    # returned by predict_speaker() — confirm against callers.
    speaker_embedding: List[float]
    # GPT conditioning latents as rows; the streaming code reshapes them
    # to (-1, 1024).
    gpt_cond_latent: List[List[float]]
    # Text to synthesize.
    text: str
    # Language passed through to model.inference_stream().
    language: str
    # When true, a WAV header is emitted before the first audio chunk.
    add_wav_header: bool = True
    # Chunk size kept as a string; the streaming code converts it with int().
    stream_chunk_size: str = "20"
#
# NOTE(review): disabled streaming endpoint, kept for reference. `Body` and
# `StreamingResponse` are not imported above, and the generator reads
# attributes (parsed_input.text, ...), so re-enabling it also needs
# `parsed_input` typed as StreamingInputs rather than dict.
#def predict_streaming_generator(parsed_input: dict = Body(...)):
#    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
#    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
#    text = parsed_input.text
#    language = parsed_input.language
#
#    stream_chunk_size = int(parsed_input.stream_chunk_size)
#    add_wav_header = parsed_input.add_wav_header
#
#
#    chunks = model.inference_stream(
#        text,
#        language,
#        gpt_cond_latent,
#        speaker_embedding,
#        stream_chunk_size=stream_chunk_size,
#        enable_text_splitting=True
#    )
#
#    for i, chunk in enumerate(chunks):
#        chunk = postprocess(chunk)
#        if i == 0 and add_wav_header:
#            yield encode_audio_common(b"", encode_base64=False)
#            yield chunk.tobytes()
#        else:
#            yield chunk.tobytes()
#
#
## @app.post("/tts_stream")
#def predict_streaming_endpoint(parsed_input: StreamingInputs):
#    return StreamingResponse(
#        predict_streaming_generator(parsed_input),
#        media_type="audio/wav",
#    )
# Request body for predict_speech(): speaker conditioning data, the text and
# language to synthesize, and sampling/speed options.
class TTSInputs(BaseModel):
    # Flat speaker embedding; predict_speech() turns it into a tensor of shape
    # (1, len, 1).
    speaker_embedding: List[float]
    # GPT conditioning latents; predict_speech() reshapes them to (1, -1, 1024).
    gpt_cond_latent: List[List[float]]
    # Text to synthesize.
    text: str
    # Language passed to model.inference().
    language: str
    # Sampling/speed options; predict_speech() passes these on to
    # model.inference().
    temperature: float
    speed: float
    top_k: int
    top_p: float
# @app.post("/tts")
def predict_speech(parsed_input: TTSInputs):
    """Synthesize ``parsed_input.text`` in the voice described by its conditioning data.

    Args:
        parsed_input: Request with the speaker embedding, GPT conditioning
            latents, text, language and sampling options.

    Returns:
        The synthesized audio as a base64-encoded WAV string
        (``encode_audio_common`` with its defaults: 24 kHz, 16-bit, mono).
    """
    # (len,) -> (1, len, 1)
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    # rows of 1024 -> (1, n_rows, 1024)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)

    # Fixed decoding penalties used for every request.
    length_penalty = 1.0
    repetition_penalty = 2.0

    # Pass the options by keyword: Xtts.inference declares `do_sample` and
    # `num_beams` before `speed`, so the previous positional call sent
    # `speed` into `do_sample` and the requested speed was silently ignored.
    out = model.inference(
        parsed_input.text,
        parsed_input.language,
        gpt_cond_latent,
        speaker_embedding,
        temperature=parsed_input.temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        top_k=parsed_input.top_k,
        top_p=parsed_input.top_p,
        speed=parsed_input.speed,
    )
    wav = postprocess(torch.tensor(out["wav"]))
    return encode_audio_common(wav.tobytes())
# @app.get("/studio_speakers")
def get_speakers():
    """List the conditioning data of the model's built-in speakers.

    Returns:
        A dict keyed by speaker name. Each value holds that speaker's
        ``"speaker_embedding"`` and ``"gpt_cond_latent"`` tensors, moved to the
        CPU, squeezed, cast to half precision and converted to nested lists.
        The dict is empty when the model has no ``speaker_manager`` or the
        manager has no ``speakers`` attribute.
    """
    manager = getattr(model, "speaker_manager", None)
    # hasattr(None, ...) is False, so this also covers a missing manager.
    if not hasattr(manager, "speakers"):
        return {}

    def _as_list(tensor):
        return tensor.cpu().squeeze().half().tolist()

    catalog = {}
    for name in manager.speakers.keys():
        entry = manager.speakers[name]
        catalog[name] = {
            "speaker_embedding": _as_list(entry["speaker_embedding"]),
            "gpt_cond_latent": _as_list(entry["gpt_cond_latent"]),
        }
    return catalog
# @app.get("/languages")
def get_languages():
    """Return the ``languages`` entry of the loaded XTTS config."""
    supported = config.languages
    return supported