import gradio as gr
import torch
import soundfile as sf
from snac import SNAC
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def find_last_instance_of_separator(lst, element=50258):
    """Return the index of the last occurrence of the separator token in lst."""
    reversed_list = lst[::-1]
    try:
        reversed_index = reversed_list.index(element)
        return len(lst) - 1 - reversed_index
    except ValueError:
        raise ValueError(f"Separator token {element} not found in output")

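# Illustration (hypothetical ids, not a real frame layout): in
# [50258, 10, 11, 50258, 12] the last separator sits at index 3, so
# find_last_instance_of_separator returns 3 and slicing with [:3] keeps
# [50258, 10, 11] while dropping the unfinished tail.
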
def reconstruct_tensors(flattened_output):
    """Rebuild the per-codebook SNAC code tensors from a flattened token list."""

    def count_elements_between_hashes(lst):
        # The token count between two consecutive separators identifies the
        # codec layout: 7 codes per frame -> 3 codebooks, 15 -> 4 codebooks.
        try:
            first_index = lst.index(50258)
            second_index = lst.index(50258, first_index + 1)
            return second_index - first_index - 1
        except ValueError:
            raise ValueError("List does not contain two separator tokens")

    def remove_elements_before_hash(flattened_list):
        # Drop any tokens emitted before the first frame separator.
        try:
            first_hash_index = flattened_list.index(50258)
            return flattened_list[first_hash_index:]
        except ValueError:
            raise ValueError("List does not contain the separator token")

    def list_to_torch_tensor(tensor1):
        # SNAC expects each codebook as a (batch, length) tensor.
        tensor = torch.tensor(tensor1)
        tensor = tensor.unsqueeze(0)
        return tensor

    # Trim to complete frames: start at the first separator and cut off
    # everything from the last separator onward.
    flattened_output = remove_elements_before_hash(flattened_output)
    last_index = find_last_instance_of_separator(flattened_output)
    flattened_output = flattened_output[:last_index]

    codes = []
    tensor1 = []
    tensor2 = []
    tensor3 = []
    tensor4 = []

    n_tensors = count_elements_between_hashes(flattened_output)
    if n_tensors == 7:
        # 3-codebook layout: each 8-token frame is
        # [sep, c1, c2, c3, c3, c2, c3, c3], i.e. codebook rates 1:2:4.
        for i in range(0, len(flattened_output), 8):
            tensor1.append(flattened_output[i + 1])
            tensor2.append(flattened_output[i + 2])
            tensor3.append(flattened_output[i + 3])
            tensor3.append(flattened_output[i + 4])
            tensor2.append(flattened_output[i + 5])
            tensor3.append(flattened_output[i + 6])
            tensor3.append(flattened_output[i + 7])
        codes = [
            list_to_torch_tensor(tensor1).to(device),
            list_to_torch_tensor(tensor2).to(device),
            list_to_torch_tensor(tensor3).to(device),
        ]

    if n_tensors == 15:
        # 4-codebook layout: each 16-token frame is
        # [sep, c1, c2, c3, c4, c4, c3, c4, c4, c2, c3, c4, c4, c3, c4, c4],
        # i.e. codebook rates 1:2:4:8.
        for i in range(0, len(flattened_output), 16):
            tensor1.append(flattened_output[i + 1])
            tensor2.append(flattened_output[i + 2])
            tensor3.append(flattened_output[i + 3])
            tensor4.append(flattened_output[i + 4])
            tensor4.append(flattened_output[i + 5])
            tensor3.append(flattened_output[i + 6])
            tensor4.append(flattened_output[i + 7])
            tensor4.append(flattened_output[i + 8])
            tensor2.append(flattened_output[i + 9])
            tensor3.append(flattened_output[i + 10])
            tensor4.append(flattened_output[i + 11])
            tensor4.append(flattened_output[i + 12])
            tensor3.append(flattened_output[i + 13])
            tensor4.append(flattened_output[i + 14])
            tensor4.append(flattened_output[i + 15])
        codes = [
            list_to_torch_tensor(tensor1).to(device),
            list_to_torch_tensor(tensor2).to(device),
            list_to_torch_tensor(tensor3).to(device),
            list_to_torch_tensor(tensor4).to(device),
        ]

    return codes

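# Worked example (made-up ids, not real SNAC codes): two complete 8-token
# frames plus a trailing separator,
#     flat = [50258, 1, 2, 3, 4, 5, 6, 7,
#             50258, 8, 9, 10, 11, 12, 13, 14, 50258]
# reconstruct_tensors(flat) yields tensors of shape (1, 2), (1, 4) and
# (1, 8), matching the 1:2:4 rates of the 3-codebook 24 kHz SNAC model.
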
def load_model():
    # Voicera checkpoint (tokenizer + causal LM) and the 24 kHz SNAC codec.
    tokenizer = AutoTokenizer.from_pretrained("Lwasinam/voicera-jenny-finetune")
    model = AutoModelForCausalLM.from_pretrained("Lwasinam/voicera-jenny-finetune").to(device)
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
    return model, tokenizer, snac_model

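# Optional speed-up (a sketch, not part of the original flow): the checkpoints
# are reloaded on every request below; caching the zero-argument loader keeps
# them in memory across calls.
#     from functools import lru_cache
#     load_model = lru_cache(maxsize=1)(load_model)
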
def SpeechDecoder(codes, snac_model):
    # Turn the generated token ids back into per-codebook SNAC tensors,
    # decode them to a waveform, and write it out as a 24 kHz WAV file.
    codes = codes.squeeze(0).tolist()
    reconstructed_codes = reconstruct_tensors(codes)
    audio_hat = snac_model.decode(reconstructed_codes)
    audio_path = "reconstructed_audio.wav"
    sf.write(audio_path, audio_hat.squeeze().cpu().detach().numpy(), 24000)
    return audio_path

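# Alternative output (an untested sketch): Gradio's audio component also
# accepts an in-memory (sample_rate, numpy_array) tuple, which would avoid
# the intermediate file:
#     return 24000, audio_hat.squeeze().cpu().detach().numpy()
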
def generate_audio(text, tokenizer, model, snac_model):
    with torch.no_grad():
        input_ids = tokenizer(text, return_tensors="pt").to(device)
        # num_beams=5 combined with do_sample=True selects transformers'
        # beam-sample decoding (beam search with nucleus sampling).
        output_codes = model.generate(
            input_ids["input_ids"],
            attention_mask=input_ids["attention_mask"],
            max_length=1024,
            num_beams=5,
            top_p=0.95,
            temperature=0.8,
            do_sample=True,
            repetition_penalty=2.0,
        )
    audio_path = SpeechDecoder(output_codes, snac_model)
    return audio_path

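# Lighter-weight decoding (a sketch, assuming plain nucleus sampling is
# acceptable for this model): dropping num_beams avoids the slower
# beam-sample path.
#     output_codes = model.generate(
#         input_ids["input_ids"],
#         attention_mask=input_ids["attention_mask"],
#         max_length=1024,
#         do_sample=True,
#         top_p=0.95,
#         temperature=0.8,
#         repetition_penalty=2.0,
#     )
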
def main(text):
    model, tokenizer, snac_model = load_model()
    audio_path = generate_audio(text, tokenizer, model, snac_model)
    return audio_path

iface = gr.Interface(
    fn=main,
    inputs="textbox",
    outputs="audio",
    title="Voicera TTS",
    description="Generate speech from text using the Voicera TTS model.",
)

if __name__ == "__main__":
    iface.launch()
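# To expose a temporary public URL (e.g. when running from a notebook),
# Gradio supports: iface.launch(share=True)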