Update tts.py
Browse files
tts.py
CHANGED
@@ -1,416 +1,174 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
from fastapi.security.api_key import APIKeyHeader, APIKey
|
4 |
-
from fastapi.responses import JSONResponse
|
5 |
-
from pydantic import BaseModel
|
6 |
-
from typing import Optional
|
7 |
-
import numpy as np
|
8 |
-
import io
|
9 |
-
import soundfile as sf
|
10 |
-
import base64
|
11 |
-
import logging
|
12 |
-
import torch
|
13 |
-
import librosa
|
14 |
-
from pathlib import Path
|
15 |
-
from pydub import AudioSegment
|
16 |
-
from moviepy.editor import VideoFileClip
|
17 |
-
import traceback
|
18 |
-
from logging.handlers import RotatingFileHandler
|
19 |
-
import boto3
|
20 |
-
from botocore.exceptions import NoCredentialsError
|
21 |
-
import time
|
22 |
import tempfile
|
23 |
-
import
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
from tts import synthesize, TTS_LANGUAGES
|
28 |
-
from lid import identify
|
29 |
-
|
30 |
-
# Logging: INFO-level basic configuration plus a rotating log file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rotate app.log once it reaches 10,000,000 bytes, keeping five backups.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = RotatingFileHandler('app.log', maxBytes=10_000_000, backupCount=5)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")

# S3 settings are read from the environment; each one is None when unset.
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# API key expected in the X-API-Key request header. auto_error=False means a
# missing header is passed to the dependency as None instead of being
# rejected by FastAPI itself.
API_KEY = os.getenv("API_KEY")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# S3 client built once at import time from the settings above.
s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
|
60 |
-
|
61 |
-
# Define request models
|
62 |
-
class AudioRequest(BaseModel):
    """Request body carrying an encoded media payload and an optional language."""

    # Base64 encoded audio or video data
    audio: str
    # Optional language; the handlers fall back to identification when None.
    language: Optional[str] = None
|
65 |
-
|
66 |
-
class TTSRequest(BaseModel):
    """Request body for speech synthesis."""

    # Text to be spoken.
    text: str
    # Optional target language.
    language: Optional[str] = None
    # Speed factor; defaults to 1.0.
    speed: float = 1.0
|
70 |
-
|
71 |
-
class LanguageRequest(BaseModel):
    """Request body for the language-listing endpoints."""

    # Optional prefix used to filter language entries; None lists everything.
    language: Optional[str] = None
|
73 |
-
|
74 |
-
async def get_api_key(api_key_header: str = Security(api_key_header)):
    """Validate the X-API-Key header against the configured API key.

    Args:
        api_key_header: value of the X-API-Key header, or None when the
            header is absent (the header scheme is built with
            auto_error=False).

    Returns:
        The supplied key when it matches API_KEY.

    Raises:
        HTTPException: 403 when the key is missing, wrong, or when no
            API_KEY is configured on the server.
    """
    import secrets

    # Fail closed: with API_KEY unset and no header sent, a plain `==` would
    # compare None to None and let the request through.
    if API_KEY and api_key_header:
        # Constant-time comparison; encode to bytes so non-ASCII keys are
        # accepted by compare_digest.
        if secrets.compare_digest(api_key_header.encode("utf-8"), API_KEY.encode("utf-8")):
            return api_key_header
    raise HTTPException(status_code=403, detail="Could not validate credentials")
|
78 |
-
|
79 |
-
def extract_audio_from_file(input_bytes):
    """Decode raw media bytes into a mono sample array and its sample rate.

    The bytes are written to a temporary file which is then tried, in order,
    with soundfile, moviepy (video container) and pydub. The temporary file
    is always removed before returning or raising.

    Args:
        input_bytes: raw bytes of an audio or video file.

    Returns:
        Tuple ``(audio_array, sample_rate)``.

    Raises:
        ValueError: if none of the three readers can decode the file.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_file:
        temp_file.write(input_bytes)
        temp_file_path = temp_file.name

    try:
        # Log file info
        file_info = magic.from_file(temp_file_path, mime=True)
        logger.info(f"Received file of type: {file_info}")

        # Try reading with soundfile first
        try:
            audio_array, sample_rate = sf.read(temp_file_path)
            # soundfile returns (frames, channels) for multi-channel input;
            # downmix so this path yields mono like the two fallbacks below.
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            logger.info(f"Successfully read audio with soundfile. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
            return audio_array, sample_rate
        except Exception as e:
            logger.info(f"Could not read with soundfile: {str(e)}")

        # Try reading as video
        try:
            video = VideoFileClip(temp_file_path)
            try:
                audio = video.audio
                if audio is not None:
                    audio_array = audio.to_soundarray()
                    sample_rate = audio.fps
                    if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
                        audio_array = audio_array.mean(axis=1)
                    audio_array = audio_array.astype(np.float32)
                    # Peak-normalise, but leave an all-zero (silent) track
                    # untouched instead of dividing by zero.
                    peak = np.max(np.abs(audio_array))
                    if peak > 0:
                        audio_array /= peak
                    logger.info(f"Successfully extracted audio from video. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
                    return audio_array, sample_rate
                logger.info("Video file contains no audio")
            finally:
                # Release the clip on every path, not only on success.
                video.close()
        except Exception as e:
            logger.info(f"Could not read as video: {str(e)}")

        # Try reading with pydub
        try:
            audio = AudioSegment.from_file(temp_file_path)
            audio_array = np.array(audio.get_array_of_samples())
            # Scale signed integer samples by 2**(bits-1): 2**7 for 1-byte,
            # 2**15 for 2-byte, and the matching value for wider samples.
            audio_array = audio_array.astype(np.float32) / float(2 ** (8 * audio.sample_width - 1))
            # Samples are interleaved per frame; average across channels.
            if audio.channels > 1:
                audio_array = audio_array.reshape((-1, audio.channels)).mean(axis=1)
            logger.info(f"Successfully read audio with pydub. Shape: {audio_array.shape}, Sample rate: {audio.frame_rate}")
            return audio_array, audio.frame_rate
        except Exception as e:
            logger.info(f"Could not read with pydub: {str(e)}")

        raise ValueError(f"Unsupported file format: {file_info}")
    finally:
        os.unlink(temp_file_path)
|
129 |
-
|
130 |
-
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
    """Transcribe base64-encoded audio or video.

    Decodes the payload, converts the samples to float32, resamples them to
    ASR_SAMPLING_RATE when needed, and transcribes with the requested
    language, or with the output of identify() when no language is given.
    Any exception is reported as a 500 JSON response with the error details.
    """
    start_time = time.time()
    try:
        raw_bytes = base64.b64decode(request.audio)
        samples, rate = extract_audio_from_file(raw_bytes)

        # The pipeline below works on float32 samples.
        samples = samples.astype(np.float32)

        # Bring the audio to the ASR sampling rate if it differs.
        if rate != ASR_SAMPLING_RATE:
            samples = librosa.resample(samples, orig_sr=rate, target_sr=ASR_SAMPLING_RATE)

        # Fall back to language identification when no language was supplied.
        language = request.language
        if language is None:
            language = identify(samples)
        result = transcribe(samples, language)

        elapsed = time.time() - start_time
        return JSONResponse(content={"transcription": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        details = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - start_time
        payload = {
            "message": "An error occurred during transcription",
            "details": details,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
|
164 |
-
|
165 |
-
@app.post("/transcribe_file")
|
166 |
-
async def transcribe_audio_file(
|
167 |
-
file: UploadFile = File(...),
|
168 |
-
language: Optional[str] = Form(None),
|
169 |
-
api_key: APIKey = Depends(get_api_key)
|
170 |
-
):
|
171 |
-
start_time = time.time()
|
172 |
-
try:
|
173 |
-
contents = await file.read()
|
174 |
-
audio_array, sample_rate = extract_audio_from_file(contents)
|
175 |
|
176 |
-
|
177 |
-
audio_array = audio_array.astype(np.float32)
|
178 |
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
identified_language = identify(audio_array)
|
186 |
-
result = transcribe(audio_array, identified_language)
|
187 |
-
else:
|
188 |
-
result = transcribe(audio_array, language)
|
189 |
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
"error": str(e),
|
196 |
-
"traceback": traceback.format_exc()
|
197 |
-
}
|
198 |
-
processing_time = time.time() - start_time
|
199 |
-
return JSONResponse(
|
200 |
-
status_code=500,
|
201 |
-
content={"message": "An error occurred during transcription", "details": error_details, "processing_time_seconds": processing_time}
|
202 |
-
)
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
)
|
268 |
-
logger.info(f"File uploaded successfully to S3: {filename}")
|
269 |
-
|
270 |
-
# Generate the public URL with the correct format
|
271 |
-
url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
|
272 |
-
logger.info(f"Public URL generated: {url}")
|
273 |
-
|
274 |
-
processing_time = time.time() - start_time
|
275 |
-
return JSONResponse(content={"audio_url": url, "processing_time_seconds": processing_time})
|
276 |
-
|
277 |
-
except NoCredentialsError:
|
278 |
-
logger.error("AWS credentials not available or invalid")
|
279 |
-
raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials")
|
280 |
-
except Exception as e:
|
281 |
-
logger.error(f"Failed to upload to S3: {str(e)}")
|
282 |
-
raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}")
|
283 |
-
|
284 |
-
except ValueError as ve:
|
285 |
-
logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
|
286 |
-
processing_time = time.time() - start_time
|
287 |
-
return JSONResponse(
|
288 |
-
status_code=400,
|
289 |
-
content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
|
290 |
-
)
|
291 |
-
except Exception as e:
|
292 |
-
logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
|
293 |
-
error_details = {
|
294 |
-
"error": str(e),
|
295 |
-
"type": type(e).__name__,
|
296 |
-
"traceback": traceback.format_exc()
|
297 |
-
}
|
298 |
-
processing_time = time.time() - start_time
|
299 |
-
return JSONResponse(
|
300 |
-
status_code=500,
|
301 |
-
content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
|
302 |
-
)
|
303 |
-
|
304 |
-
@app.post("/identify")
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
    """Run language identification on base64-encoded audio or video.

    Returns the identify() result together with the processing time, or a
    500 JSON response with the error details when anything fails.
    """
    start_time = time.time()
    try:
        raw_bytes = base64.b64decode(request.audio)
        # The sample rate is returned by the extractor but not used here.
        samples, _ = extract_audio_from_file(raw_bytes)
        result = identify(samples)
        elapsed = time.time() - start_time
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        details = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - start_time
        payload = {
            "message": "An error occurred during language identification",
            "details": details,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
|
324 |
|
325 |
-
|
326 |
-
async def identify_language_file(
    file: UploadFile = File(...),
    api_key: APIKey = Depends(get_api_key)
):
    """Run language identification on an uploaded audio or video file.

    Returns the identify() result together with the processing time, or a
    500 JSON response with the error details when anything fails.
    """
    start_time = time.time()
    try:
        uploaded = await file.read()
        # The sample rate is returned by the extractor but not used here.
        samples, _ = extract_audio_from_file(uploaded)
        result = identify(samples)
        elapsed = time.time() - start_time
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in identify_language_file: {str(e)}", exc_info=True)
        details = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - start_time
        payload = {
            "message": "An error occurred during language identification",
            "details": details,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
|
348 |
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
else:
|
357 |
-
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
358 |
-
|
359 |
-
processing_time = time.time() - start_time
|
360 |
-
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
361 |
-
except Exception as e:
|
362 |
-
logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
|
363 |
-
error_details = {
|
364 |
-
"error": str(e),
|
365 |
-
"traceback": traceback.format_exc()
|
366 |
-
}
|
367 |
-
processing_time = time.time() - start_time
|
368 |
-
return JSONResponse(
|
369 |
-
status_code=500,
|
370 |
-
content={"message": "An error occurred while fetching ASR languages", "details": error_details, "processing_time_seconds": processing_time}
|
371 |
-
)
|
372 |
-
|
373 |
-
@app.post("/tts_languages")
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
    """List TTS languages, optionally filtered by a case-insensitive prefix.

    With no (or an empty) prefix, TTS_LANGUAGES is returned as-is; otherwise
    only the entries that start with the prefix are returned. Failures are
    reported as a 500 JSON response with the error details.
    """
    start_time = time.time()
    try:
        prefix = request.language
        if not prefix:
            # None or "": return all languages.
            matching_languages = TTS_LANGUAGES
        else:
            needle = prefix.lower()
            matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(needle)]

        elapsed = time.time() - start_time
        return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
        details = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - start_time
        payload = {
            "message": "An error occurred while fetching TTS languages",
            "details": details,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
|
396 |
-
|
397 |
-
@app.get("/health")
async def health_check():
    """Liveness probe: always reports a fixed "ok" status."""
    status_payload = {"status": "ok"}
    return status_payload
|
400 |
-
|
401 |
-
@app.get("/")
async def root():
    """Return a welcome message, the API version and the endpoint paths."""
    endpoints = [
        "/transcribe",
        "/transcribe_file",
        "/synthesize",
        "/identify",
        "/identify_file",
        "/asr_languages",
        "/tts_languages",
        "/health",
    ]
    return {
        "message": "Welcome to the MMS Speech Technology API",
        "version": "1.0",
        "endpoints": endpoints,
    }
|
|
|
1 |
import os
|
2 |
+
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import tempfile
|
4 |
+
import torch
|
5 |
+
import sys
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
from huggingface_hub import hf_hub_download
|
|
|
10 |
|
11 |
+
# Setup TTS env
|
12 |
+
if "vits" not in sys.path:
|
13 |
+
sys.path.append("vits")
|
14 |
|
15 |
+
from vits import commons, utils
|
16 |
+
from vits.models import SynthesizerTrn
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
# Maps a language code (first column of the file) to its display name
# (rest of the line), loaded once at import time.
TTS_LANGUAGES = {}
with open("data/tts/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        # Split on the first run of whitespace so both space- and
        # tab-separated rows are accepted; skip blank or one-column lines
        # rather than failing the whole import on an unpack error.
        fields = line.split(None, 1)
        if len(fields) != 2:
            continue
        iso, name = fields
        TTS_LANGUAGES[iso.strip()] = name.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
class TextMapper(object):
    """Maps text to symbol-ID sequences using a one-symbol-per-line vocabulary."""

    def __init__(self, vocab_file):
        """Load the vocabulary.

        Args:
            vocab_file: path to a UTF-8 text file with one symbol per line.

        Raises:
            ValueError: if the vocabulary contains no single-space symbol.
        """
        # Context manager so the file handle is closed deterministically.
        with open(vocab_file, encoding="utf-8") as f:
            self.symbols = [x.replace("\n", "") for x in f.readlines()]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}

    def text_to_sequence(self, text, cleaner_names):
        """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

        Args:
            text: string to convert to a sequence
            cleaner_names: accepted for interface compatibility; not used here

        Returns:
            List of integers corresponding to the symbols in the text

        Raises:
            KeyError: if the stripped text contains a symbol that is not in
                the vocabulary (call filter_oov first to drop such symbols).
        """
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def uromanize(self, text, uroman_pl):
        """Romanize text by piping it through the uroman Perl script.

        Args:
            text: text to romanize
            uroman_pl: path to the uroman.pl script

        Returns:
            The first line of the script's output with whitespace runs
            collapsed to single spaces, or "" if the script printed nothing.

        Raises:
            RuntimeError: if the Perl process exits with a non-zero status.
        """
        import subprocess

        iso = "xxx"
        # Argument list instead of a shell string (no quoting issues), and
        # pipes instead of re-opening NamedTemporaryFile paths, which is not
        # portable. UTF-8 is used explicitly for both directions.
        proc = subprocess.run(
            ["perl", uroman_pl, "-l", iso],
            input=text,
            capture_output=True,
            text=True,
            encoding="utf-8",
        )
        if proc.returncode != 0:
            raise RuntimeError(
                f"uroman failed with exit code {proc.returncode}: {proc.stderr.strip()}"
            )
        # Only the first output line is used, as before.
        first_line = proc.stdout.split("\n", 1)[0]
        return re.sub(r"\s+", " ", first_line).strip()

    def get_text(self, text, hps):
        """Convert text to a LongTensor of symbol IDs.

        When ``hps.data.add_blank`` is set, ID 0 is interspersed between the
        symbol IDs via ``commons.intersperse``.
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def filter_oov(self, text, lang=None):
        """Apply per-language preprocessing, then drop out-of-vocabulary characters."""
        text = self.preprocess_char(text, lang=lang)
        val_chars = self._symbol_to_id
        return "".join(ch for ch in text if ch in val_chars)

    def preprocess_char(self, text, lang=None):
        """
        Special treatment of characters in certain languages.

        ``lang`` may be a bare code ("ron") or a label whose first
        whitespace-separated token is the code ("ron (Romanian)").
        """
        # Compare on the code only; the full label never equals "ron".
        parts = lang.split() if lang else []
        lang_code = parts[0] if parts else None
        if lang_code == "ron":
            text = text.replace("ț", "ţ")
            print(f"{lang} (ț -> ţ): {text}")
        return text
|
86 |
+
|
87 |
+
def synthesize(text=None, lang=None, speed=None):
    """Synthesize speech for ``text`` with the MMS TTS model of ``lang``.

    The vocabulary, config and generator checkpoint for the language are
    fetched with ``hf_hub_download`` from ``facebook/mms-tts`` on every call.

    Args:
        text: text to speak.
        lang: language label whose first whitespace-separated token is the
            model's language code, e.g. "eng (English)" or just "eng".
        speed: speed factor; ``None`` means 1.0. The model's length scale
            is ``1.0 / speed``.

    Returns:
        ``((sampling_rate, waveform), text)`` where ``waveform`` is a float32
        NumPy array and ``text`` is the lower-cased, vocabulary-filtered
        text that was actually synthesized.

    Raises:
        ValueError: if ``text`` or ``lang`` is missing, or ``speed`` <= 0.
        FileNotFoundError: if the downloaded config file or the local
            ``uroman`` directory (needed for uroman-trained models) is missing.
    """
    # Validate up front so bad input fails clearly instead of surfacing as
    # AttributeError / ZeroDivisionError deep inside the pipeline.
    if text is None:
        raise ValueError("text must be provided")
    if not lang or not lang.strip():
        raise ValueError("lang must be provided, e.g. 'eng (English)'")
    if speed is None:
        speed = 1.0
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")

    lang_code = lang.split()[0].strip()

    model_dir = f"models/{lang_code}"
    vocab_file = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="vocab.txt",
        subfolder=model_dir,
    )
    config_file = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="config.json",
        subfolder=model_dir,
    )
    g_pth = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="G_100000.pth",
        subfolder=model_dir,
    )

    device = _select_device()
    print(f"Run inference with {device}")

    # Explicit raise rather than assert: asserts are stripped under -O.
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} doesn't exist")
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)

    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    ).to(device)
    net_g.eval()

    _ = utils.load_checkpoint(g_pth, net_g, None)

    # The training file list's extension marks models trained on uroman text.
    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"

    if is_uroman:
        uroman_dir = "uroman"
        if not os.path.exists(uroman_dir):
            raise FileNotFoundError(f"{uroman_dir} directory doesn't exist")
        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
        text = text_mapper.uromanize(text, uroman_pl)

    text = text.lower()
    # Pass the bare code: per-language character handling compares against
    # codes such as "ron", which the full label never equals.
    text = text_mapper.filter_oov(text, lang=lang_code)
    stn_tst = text_mapper.get_text(text, hps).to(device)

    # Use autocast for mixed-precision inference, only when running on CUDA.
    with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
            hyp = (
                net_g.infer(
                    x_tst,
                    x_tst_lengths,
                    noise_scale=0.667,
                    noise_scale_w=0.8,
                    length_scale=1.0 / speed,
                )[0][0, 0]
                .cpu()
                .float()  # Convert to float32 for numpy
                .numpy()
            )

    return (hps.data.sampling_rate, hyp), text


def _select_device():
    """Return the CUDA device if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        return torch.device("mps")
    return torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
+
# Example inputs as [text, language label, speed] rows; the row order matches
# the (text, lang, speed) parameters of synthesize().
TTS_EXAMPLES = [
    [
        "I am going to the store.",
        "eng (English)",
        1.0,
    ],
    [
        "안녕하세요.",
        "kor (Korean)",
        1.0,
    ],
    [
        "क्या मुझे पीने का पानी मिल सकता है?",
        "hin (Hindi)",
        1.0,
    ],
    [
        "Tanış olmağıma çox şadam",
        "azj-script_latin (Azerbaijani, North)",
        1.0,
    ],
    [
        "Mu zo murna a cikin ƙasar.",
        "hau (Hausa)",
        1.0,
    ],
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|