parler-tts-ie / handler.py
sergeipetrov's picture
Update handler.py
9647d8d verified
raw
history blame contribute delete
No virus
1.47 kB
import base64
import io
import logging

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed
# Grab the root logger (no name given) and lower its threshold to DEBUG so
# that debug-level records are not filtered out at the root.
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.DEBUG)
class EndpointHandler:
    """Callable handler that turns a text request into base64-encoded WAV audio.

    The Parler-TTS checkpoint named by ``MODEL_ID`` and its tokenizer are
    loaded once at construction time; every call then runs one generation
    and returns the resulting audio.
    """

    # Checkpoint used for both the model weights and the tokenizer.
    MODEL_ID = "parler-tts/parler-tts-mini-expresso"

    def __init__(self, path=""):
        """Load the model and tokenizer onto the best available device.

        Args:
            path: Unused; kept so callers that pass a path keep working.
                The model is always loaded from ``MODEL_ID``.
        """
        on_gpu = torch.cuda.is_available()
        self.device = "cuda:0" if on_gpu else "cpu"
        # Half precision is only a good fit on GPU; fall back to float32 on
        # CPU, where float16 kernels are slow or unavailable.
        dtype = torch.float16 if on_gpu else torch.float32
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            torch_dtype=dtype,
        ).to(self.device)
        # torch.compile of the forward pass is left disabled:
        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)

    def __call__(self, data):
        """Generate audio for one request.

        Args:
            data: Mapping with an ``"inputs"`` entry, itself a mapping with
                ``"prompt"`` (tokenized and passed to the model as
                ``prompt_input_ids``) and ``"description"`` (tokenized and
                passed as ``input_ids``).

        Returns:
            str: The generated audio as a WAV file, base64-encoded.

        Raises:
            KeyError: If ``"inputs"``, ``"prompt"`` or ``"description"`` is
                missing from the request mapping.
        """
        inputs = data["inputs"]
        prompt = inputs["prompt"]
        description = inputs["description"]

        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        # Re-seed before every generation so each request starts from the
        # same RNG state.
        set_seed(42)
        generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)

        # soundfile only accepts float32/float64/int16/int32 arrays, so upcast
        # a possibly half-precision result before converting to NumPy.
        audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()

        # Encode in memory instead of via a fixed file name on disk: a shared
        # path would be overwritten by concurrent requests and left behind.
        buffer = io.BytesIO()
        sf.write(buffer, audio_arr, self.model.config.sampling_rate, format="WAV")
        return base64.b64encode(buffer.getvalue()).decode()