Blakus committed · Commit f349162 · verified · 1 Parent(s): 8f0bc36

Update app.py

Files changed (1)
  1. app.py +0 -148
app.py CHANGED
@@ -1,148 +0,0 @@
-import sys
-import io, os, stat
-import subprocess
-import random
-from zipfile import ZipFile
-import uuid
-import time
-import torch
-import torchaudio
-import time
-# Keep the MeCab download
-os.system('python -m unidic download')
-
-# Keep the CPML agreement
-os.environ["COQUI_TOS_AGREED"] = "1"
-
-import langid
-import base64
-import csv
-from io import StringIO
-import datetime
-import re
-
-import gradio as gr
-from scipy.io.wavfile import write
-from pydub import AudioSegment
-
-from TTS.api import TTS
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.xtts import Xtts
-from TTS.utils.generic_utils import get_user_data_dir
-
-HF_TOKEN = os.environ.get("HF_TOKEN")
-
-from huggingface_hub import hf_hub_download
-import os
-from TTS.utils.manage import get_user_data_dir
-
-# Keep the authentication and model download
-repo_id = "Blakus/Pedro_Lab_XTTS"
-local_dir = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
-os.makedirs(local_dir, exist_ok=True)
-files_to_download = ["config.json", "model.pth", "vocab.json"]
-for file_name in files_to_download:
-    print(f"Downloading {file_name} from {repo_id}")
-    local_file_path = os.path.join(local_dir, file_name)
-    hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=local_dir)
-
-# Load configuration and model
-config_path = os.path.join(local_dir, "config.json")
-checkpoint_path = os.path.join(local_dir, "model.pth")
-vocab_path = os.path.join(local_dir, "vocab.json")
-
-config = XttsConfig()
-config.load_json(config_path)
-
-model = Xtts.init_from_config(config)
-model.load_checkpoint(config, checkpoint_path=checkpoint_path, vocab_path=vocab_path, eval=True, use_deepspeed=False)
-
-print("Modelo cargado en CPU")
-
-# Keep global variables and helper functions
-DEVICE_ASSERT_DETECTED = 0
-DEVICE_ASSERT_PROMPT = None
-DEVICE_ASSERT_LANG = None
-supported_languages = config.languages
-
-# Inference function using the default parameters from the config file
-def predict(prompt, language, audio_file_pth, mic_file_path, use_mic):
-    try:
-        if use_mic:
-            speaker_wav = mic_file_path
-        else:
-            speaker_wav = audio_file_pth
-
-        if len(prompt) < 2 or len(prompt) > 200:
-            return None, None, "El texto debe tener entre 2 y 200 caracteres."
-
-        # Use the values from the configuration directly
-        temperature = getattr(config, "temperature", 0.75)
-        repetition_penalty = getattr(config, "repetition_penalty", 5.0)
-        gpt_cond_len = getattr(config, "gpt_cond_len", 30)
-        gpt_cond_chunk_len = getattr(config, "gpt_cond_chunk_len", 4)
-        max_ref_length = getattr(config, "max_ref_len", 60)
-
-        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
-            audio_path=speaker_wav,
-            gpt_cond_len=gpt_cond_len,
-            gpt_cond_chunk_len=gpt_cond_chunk_len,
-            max_ref_length=max_ref_length
-        )
-
-        # Measure the inference time manually
-        start_time = time.time()
-        out = model.inference(
-            prompt,
-            language,
-            gpt_cond_latent,
-            speaker_embedding,
-            temperature=temperature,
-            repetition_penalty=repetition_penalty,
-        )
-        inference_time = time.time() - start_time
-
-        torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
-
-        # Compute the metrics using the manually measured time
-        audio_length = len(out["wav"]) / 24000  # audio duration in seconds
-        real_time_factor = inference_time / audio_length
-
-        metrics_text = f"Tiempo de generación: {inference_time:.2f} segundos\n"
-        metrics_text += f"Factor de tiempo real: {real_time_factor:.2f}"
-
-        return gr.make_waveform("output.wav"), "output.wav", metrics_text
-
-    except Exception as e:
-        print(f"Error detallado: {str(e)}")
-        return None, None, f"Error: {str(e)}"
-
-
-# Updated Gradio interface without sliders
-with gr.Blocks(theme=gr.themes.Base()) as demo:
-    gr.Markdown("# Sintetizador de Voz XTTS")
-
-    with gr.Row():
-        with gr.Column():
-            input_text = gr.Textbox(label="Texto a sintetizar", placeholder="Escribe aquí el texto que quieres convertir a voz...")
-            language = gr.Dropdown(label="Idioma", choices=supported_languages, value="es")
-            audio_file = gr.Audio(label="Audio de referencia", type="filepath")
-            use_mic = gr.Checkbox(label="Usar micrófono")
-            mic_file = gr.Audio(label="Grabar con micrófono", source="microphone", type="filepath", visible=False)
-
-            use_mic.change(fn=lambda x: gr.update(visible=x), inputs=[use_mic], outputs=[mic_file])
-
-            generate_button = gr.Button("Generar voz")
-
-        with gr.Column():
-            output_audio = gr.Audio(label="Audio generado")
-            waveform = gr.Image(label="Forma de onda")
-            metrics = gr.Textbox(label="Métricas")
-
-    generate_button.click(
-        predict,
-        inputs=[input_text, language, audio_file, mic_file, use_mic],
-        outputs=[waveform, output_audio, metrics]
-    )
-
-demo.launch(debug=True)
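
For reference, the removed file loaded a fine-tuned XTTS v2 checkpoint from Blakus/Pedro_Lab_XTTS and served it through a Gradio UI. Below is a minimal sketch of the same synthesis path with the UI stripped away, reconstructed from the diff above; the download directory xtts_model, the reference clip speaker.wav, and the sample sentence are placeholders, and it assumes the TTS, torch, torchaudio, and huggingface_hub packages are installed.

# Minimal sketch of the synthesis path from the removed app.py (UI omitted).
# Placeholders: xtts_model/ download dir, speaker.wav reference clip,
# and the sample sentence.
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

local_dir = "xtts_model"  # placeholder; app.py used get_user_data_dir("tts")
os.makedirs(local_dir, exist_ok=True)
for name in ("config.json", "model.pth", "vocab.json"):
    hf_hub_download(repo_id="Blakus/Pedro_Lab_XTTS", filename=name, local_dir=local_dir)

config = XttsConfig()
config.load_json(os.path.join(local_dir, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(local_dir, "model.pth"),
    vocab_path=os.path.join(local_dir, "vocab.json"),
    eval=True,
)

# Conditioning latents from a short reference recording of the target voice.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path="speaker.wav")

# Synthesize and write a 24 kHz mono WAV, as app.py did.
out = model.inference("Hola, esto es una prueba.", "es", gpt_cond_latent, speaker_embedding)
torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)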