Update tts.py
Browse files
@@ -1,416 +1,174 @@
1 |
import os
2 |
3 |
from fastapi.security.api_key import APIKeyHeader, APIKey
4 |
from fastapi.responses import JSONResponse
5 |
from pydantic import BaseModel
6 |
from typing import Optional
7 |
import numpy as np
8 |
import io
9 |
import soundfile as sf
10 |
import base64
11 |
import logging
12 |
import torch
13 |
import librosa
14 |
from pathlib import Path
15 |
from pydub import AudioSegment
16 |
from moviepy.editor import VideoFileClip
17 |
import traceback
18 |
from logging.handlers import RotatingFileHandler
19 |
import boto3
20 |
from botocore.exceptions import NoCredentialsError
21 |
import time
22 |
import tempfile
23 |
24 |
25 |
26 |
27 |
from tts import synthesize, TTS_LANGUAGES
28 |
from lid import identify
29 |
30 |
# Configure logging
31 |
32 |
logger = logging.getLogger(__name__)
33 |
34 |
# Add a file handler
35 |
file_handler = RotatingFileHandler('app.log', maxBytes=10000000, backupCount=5)
36 |
37 |
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
38 |
39 |
40 |
41 |
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")
42 |
43 |
# S3 Configuration
44 |
S3_BUCKET = os.environ.get("S3_BUCKET")
45 |
S3_REGION = os.environ.get("S3_REGION")
46 |
S3_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
47 |
48 |
49 |
# API Key Configuration
50 |
API_KEY = os.environ.get("API_KEY")
51 |
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
52 |
53 |
# Initialize S3 client
54 |
s3_client = boto3.client(
55 |
56 |
57 |
58 |
59 |
60 |
61 |
# Define request models
62 |
class AudioRequest(BaseModel):
    """Request body carrying an encoded media payload and an optional language."""

    # Base64 encoded audio or video data
    audio: str
    # Optional language; when None, the transcription handler runs language
    # identification on the decoded audio and transcribes with that result.
    language: Optional[str] = None
65 |
66 |
class TTSRequest(BaseModel):
    """Request body for speech synthesis."""

    # Text to be synthesized.
    text: str
    # Optional language selector; defaults to None.
    language: Optional[str] = None
    # Speed value; defaults to 1.0.
    # NOTE(review): presumably forwarded as `speed` to tts.synthesize -- confirm
    # against the synthesis handler, which is not visible here.
    speed: float = 1.0
70 |
71 |
class LanguageRequest(BaseModel):
    """Request body for the language-listing endpoints."""

    # Optional case-insensitive prefix used to filter the language list;
    # None or an empty string makes the TTS listing return every language.
    language: Optional[str] = None
73 |
74 |
async def get_api_key(api_key_header: str = Security(api_key_header)):
    """Validate the ``X-API-Key`` request header against the configured key.

    Args:
        api_key_header: Header value extracted by the ``APIKeyHeader``
            security scheme. It can be ``None`` when the header is absent,
            because the scheme is declared with ``auto_error=False``.

    Returns:
        The supplied key when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the key is missing or wrong, or when no
            ``API_KEY`` is configured in the environment.
    """
    import secrets

    # Fail closed: with API_KEY unset and no header sent, a plain ``==``
    # would compare None with None and let the request through.
    if API_KEY and api_key_header:
        # Constant-time comparison so response timing does not leak how many
        # leading characters of the key were correct. Bytes are used because
        # compare_digest rejects non-ASCII str arguments.
        if secrets.compare_digest(
            api_key_header.encode("utf-8"), API_KEY.encode("utf-8")
        ):
            return api_key_header
    raise HTTPException(status_code=403, detail="Could not validate credentials")
78 |
79 |
def extract_audio_from_file(input_bytes):
80 |
with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_file:
81 |
82 |
temp_file_path = temp_file.name
83 |
84 |
85 |
# Log file info
86 |
file_info = magic.from_file(temp_file_path, mime=True)
87 |
logger.info(f"Received file of type: {file_info}")
88 |
89 |
# Try reading with soundfile first
90 |
91 |
audio_array, sample_rate = sf.read(temp_file_path)
92 |
logger.info(f"Successfully read audio with soundfile. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
93 |
return audio_array, sample_rate
94 |
except Exception as e:
95 |
logger.info(f"Could not read with soundfile: {str(e)}")
96 |
97 |
# Try reading as video
98 |
99 |
video = VideoFileClip(temp_file_path)
100 |
audio = video.audio
101 |
if audio is not None:
102 |
audio_array = audio.to_soundarray()
103 |
sample_rate = audio.fps
104 |
audio_array = audio_array.mean(axis=1) if len(audio_array.shape) > 1 and audio_array.shape[1] > 1 else audio_array
105 |
audio_array = audio_array.astype(np.float32)
106 |
audio_array /= np.max(np.abs(audio_array))
107 |
108 |
logger.info(f"Successfully extracted audio from video. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
109 |
return audio_array, sample_rate
110 |
111 |
logger.info("Video file contains no audio")
112 |
except Exception as e:
113 |
logger.info(f"Could not read as video: {str(e)}")
114 |
115 |
# Try reading with pydub
116 |
117 |
audio = AudioSegment.from_file(temp_file_path)
118 |
audio_array = np.array(audio.get_array_of_samples())
119 |
audio_array = audio_array.astype(np.float32) / (2**15 if audio.sample_width == 2 else 2**7)
120 |
audio_array = audio_array.reshape((-1, 2)).mean(axis=1) if audio.channels == 2 else audio_array
121 |
logger.info(f"Successfully read audio with pydub. Shape: {audio_array.shape}, Sample rate: {audio.frame_rate}")
122 |
return audio_array, audio.frame_rate
123 |
except Exception as e:
124 |
logger.info(f"Could not read with pydub: {str(e)}")
125 |
126 |
raise ValueError(f"Unsupported file format: {file_info}")
127 |
128 |
129 |
130 |
131 |
async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
132 |
start_time = time.time()
133 |
134 |
input_bytes = base64.b64decode(request.audio)
135 |
audio_array, sample_rate = extract_audio_from_file(input_bytes)
136 |
137 |
# Ensure audio_array is float32
138 |
audio_array = audio_array.astype(np.float32)
139 |
140 |
# Resample if necessary
141 |
if sample_rate != ASR_SAMPLING_RATE:
142 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
143 |
144 |
if request.language is None:
145 |
# If no language is provided, use language identification
146 |
identified_language = identify(audio_array)
147 |
result = transcribe(audio_array, identified_language)
148 |
149 |
result = transcribe(audio_array, request.language)
150 |
151 |
processing_time = time.time() - start_time
152 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
153 |
except Exception as e:
154 |
logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
155 |
error_details = {
156 |
"error": str(e),
157 |
"traceback": traceback.format_exc()
158 |
159 |
processing_time = time.time() - start_time
160 |
return JSONResponse(
161 |
162 |
content={"message": "An error occurred during transcription", "details": error_details, "processing_time_seconds": processing_time}
163 |
164 |
165 |
166 |
async def transcribe_audio_file(
167 |
file: UploadFile = File(...),
168 |
language: Optional[str] = Form(None),
169 |
api_key: APIKey = Depends(get_api_key)
170 |
171 |
start_time = time.time()
172 |
173 |
contents = await file.read()
174 |
audio_array, sample_rate = extract_audio_from_file(contents)
175 |
176 |
177 |
audio_array = audio_array.astype(np.float32)
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
identified_language = identify(audio_array)
186 |
result = transcribe(audio_array, identified_language)
187 |
188 |
result = transcribe(audio_array, language)
189 |
190 |
191 |
192 |
193 |
194 |
195 |
"error": str(e),
196 |
"traceback": traceback.format_exc()
197 |
198 |
processing_time = time.time() - start_time
199 |
return JSONResponse(
200 |
201 |
content={"message": "An error occurred during transcription", "details": error_details, "processing_time_seconds": processing_time}
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
logger.info(f"File uploaded successfully to S3: {filename}")
269 |
270 |
# Generate the public URL with the correct format
271 |
url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
272 |
logger.info(f"Public URL generated: {url}")
273 |
274 |
processing_time = time.time() - start_time
275 |
return JSONResponse(content={"audio_url": url, "processing_time_seconds": processing_time})
276 |
277 |
except NoCredentialsError:
278 |
logger.error("AWS credentials not available or invalid")
279 |
raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials")
280 |
except Exception as e:
281 |
logger.error(f"Failed to upload to S3: {str(e)}")
282 |
raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}")
283 |
284 |
except ValueError as ve:
285 |
logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
286 |
processing_time = time.time() - start_time
287 |
return JSONResponse(
288 |
289 |
content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
290 |
291 |
except Exception as e:
292 |
logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
293 |
error_details = {
294 |
"error": str(e),
295 |
"type": type(e).__name__,
296 |
"traceback": traceback.format_exc()
297 |
298 |
processing_time = time.time() - start_time
299 |
return JSONResponse(
300 |
301 |
content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
302 |
303 |
304 |
305 |
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
306 |
start_time = time.time()
307 |
308 |
input_bytes = base64.b64decode(request.audio)
309 |
audio_array, sample_rate = extract_audio_from_file(input_bytes)
310 |
result = identify(audio_array)
311 |
processing_time = time.time() - start_time
312 |
return JSONResponse(content={"language_identification": result, "processing_time_seconds": processing_time})
313 |
except Exception as e:
314 |
logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
315 |
error_details = {
316 |
"error": str(e),
317 |
"traceback": traceback.format_exc()
318 |
319 |
processing_time = time.time() - start_time
320 |
return JSONResponse(
321 |
322 |
content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
323 |
324 |
325 |
326 |
async def identify_language_file(
327 |
file: UploadFile = File(...),
328 |
api_key: APIKey = Depends(get_api_key)
329 |
330 |
start_time = time.time()
331 |
332 |
contents = await file.read()
333 |
audio_array, sample_rate = extract_audio_from_file(contents)
334 |
result = identify(audio_array)
335 |
processing_time = time.time() - start_time
336 |
return JSONResponse(content={"language_identification": result, "processing_time_seconds": processing_time})
337 |
except Exception as e:
338 |
logger.error(f"Error in identify_language_file: {str(e)}", exc_info=True)
339 |
error_details = {
340 |
"error": str(e),
341 |
"traceback": traceback.format_exc()
342 |
343 |
processing_time = time.time() - start_time
344 |
return JSONResponse(
345 |
346 |
content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
358 |
359 |
processing_time = time.time() - start_time
360 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
361 |
except Exception as e:
362 |
logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
363 |
error_details = {
364 |
"error": str(e),
365 |
"traceback": traceback.format_exc()
366 |
367 |
processing_time = time.time() - start_time
368 |
return JSONResponse(
369 |
370 |
content={"message": "An error occurred while fetching ASR languages", "details": error_details, "processing_time_seconds": processing_time}
371 |
372 |
373 |
374 |
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
375 |
start_time = time.time()
376 |
377 |
if request.language is None or request.language == "":
378 |
# If no language is provided, return all languages
379 |
matching_languages = TTS_LANGUAGES
380 |
381 |
matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
382 |
383 |
processing_time = time.time() - start_time
384 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
385 |
except Exception as e:
386 |
logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
387 |
error_details = {
388 |
"error": str(e),
389 |
"traceback": traceback.format_exc()
390 |
391 |
processing_time = time.time() - start_time
392 |
return JSONResponse(
393 |
394 |
content={"message": "An error occurred while fetching TTS languages", "details": error_details, "processing_time_seconds": processing_time}
395 |
396 |
397 |
398 |
async def health_check():
    """Return a static payload indicating that the service is responding."""
    return {"status": "ok"}
400 |
401 |
402 |
async def root():
403 |
return {
404 |
"message": "Welcome to the MMS Speech Technology API",
405 |
"version": "1.0",
406 |
"endpoints": [
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
1 |
import os
2 |
import re
3 |
import tempfile
4 |
import torch
5 |
import sys
6 |
import gradio as gr
7 |
import numpy as np
8 |
9 |
from huggingface_hub import hf_hub_download
10 |
11 |
# Setup TTS env
12 |
if "vits" not in sys.path:
13 |
14 |
15 |
from vits import commons, utils
16 |
from vits.models import SynthesizerTrn
17 |
18 |
19 |
with open(f"data/tts/all_langs.tsv") as f:
20 |
for line in f:
21 |
iso, name = line.split(" ", 1)
22 |
TTS_LANGUAGES[iso.strip()] = name.strip()
23 |
24 |
class TextMapper(object):
25 |
def __init__(self, vocab_file):
26 |
self.symbols = [
27 |
x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()
28 |
29 |
self.SPACE_ID = self.symbols.index(" ")
30 |
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
31 |
self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
32 |
33 |
def text_to_sequence(self, text, cleaner_names):
    """Convert a string of text to the sequence of IDs of its symbols.

    Args:
        text: string to convert to a sequence; leading and trailing
            whitespace is stripped before conversion
        cleaner_names: names of the cleaner functions to run the text
            through (accepted for interface compatibility, not used here)

    Returns:
        List of integers corresponding to the symbols in the text

    Raises:
        KeyError: if a character of the stripped text is not a known symbol.
    """
    symbol_to_id = self._symbol_to_id
    return [symbol_to_id[symbol] for symbol in text.strip()]
47 |
48 |
def uromanize(self, text, uroman_pl):
49 |
iso = "xxx"
50 |
with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
51 |
with open(tf.name, "w") as f:
52 |
53 |
cmd = f"perl " + uroman_pl
54 |
cmd += f" -l {iso} "
55 |
cmd += f" < {tf.name} > {tf2.name}"
56 |
57 |
outtexts = []
58 |
with open(tf2.name) as f:
59 |
for line in f:
60 |
line = re.sub(r"\s+", " ", line).strip()
61 |
62 |
outtext = outtexts[0]
63 |
return outtext
64 |
65 |
def get_text(self, text, hps):
    """Encode ``text`` as a tensor of symbol IDs.

    Args:
        text: string to encode; every character must be a known symbol
            (see ``text_to_sequence``).
        hps: hyper-parameter object; ``hps.data.text_cleaners`` and
            ``hps.data.add_blank`` are read.

    Returns:
        ``torch.LongTensor`` holding the ID sequence.
    """
    text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # NOTE(review): presumably inserts ID 0 between symbols -- confirm
        # against vits.commons.intersperse.
        text_norm = commons.intersperse(text_norm, 0)
    return torch.LongTensor(text_norm)
71 |
72 |
def filter_oov(self, text, lang=None):
    """Remove every character of ``text`` that is not a known symbol.

    Args:
        text: input string.
        lang: language value forwarded to ``preprocess_char``, which is
            applied before filtering.

    Returns:
        The preprocessed text with out-of-vocabulary characters dropped.
    """
    text = self.preprocess_char(text, lang=lang)
    known_symbols = self._symbol_to_id
    return "".join(char for char in text if char in known_symbols)
77 |
78 |
def preprocess_char(self, text, lang=None):
    """
    Special treatment of characters in certain languages

    Args:
        text: input string.
        lang: language value; only the exact value "ron" triggers a
            replacement.

    Returns:
        The text, with the "ron" character substitution applied if requested.
    """
    if lang == "ron":
        text = text.replace("ț", "ţ")
        print(f"{lang} (ț -> ţ): {text}")
    return text
86 |
87 |
def synthesize(text=None, lang=None, speed=None):
88 |
if speed is None:
89 |
speed = 1.0
90 |
91 |
lang_code = lang.split()[0].strip()
92 |
93 |
vocab_file = hf_hub_download(
94 |
95 |
96 |
97 |
98 |
config_file = hf_hub_download(
99 |
100 |
101 |
102 |
103 |
g_pth = hf_hub_download(
104 |
105 |
106 |
107 |
108 |
109 |
if torch.cuda.is_available():
110 |
device = torch.device("cuda")
111 |
elif (
112 |
hasattr(torch.backends, "mps")
113 |
and torch.backends.mps.is_available()
114 |
and torch.backends.mps.is_built()
115 |
116 |
device = torch.device("mps")
117 |
118 |
device = torch.device("cpu")
119 |
120 |
print(f"Run inference with {device}")
121 |
122 |
assert os.path.isfile(config_file), f"{config_file} doesn't exist"
123 |
hps = utils.get_hparams_from_file(config_file)
124 |
text_mapper = TextMapper(vocab_file)
125 |
126 |
net_g = SynthesizerTrn(
127 |
128 |
hps.data.filter_length // 2 + 1,
129 |
hps.train.segment_size // hps.data.hop_length,
130 |
131 |
132 |
133 |
134 |
_ = utils.load_checkpoint(g_pth, net_g, None)
135 |
136 |
is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
137 |
138 |
if is_uroman:
139 |
uroman_dir = "uroman"
140 |
assert os.path.exists(uroman_dir)
141 |
uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
142 |
text = text_mapper.uromanize(text, uroman_pl)
143 |
144 |
text = text.lower()
145 |
text = text_mapper.filter_oov(text, lang=lang)
146 |
stn_tst = text_mapper.get_text(text, hps).to(device)
147 |
148 |
# Use autocast for mixed-precision inference
149 |
with torch.cuda.amp.autocast(enabled=True):
150 |
with torch.no_grad():
151 |
x_tst = stn_tst.unsqueeze(0)
152 |
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
153 |
hyp = (
154 |
155 |
156 |
157 |
158 |
159 |
length_scale=1.0 / speed,
160 |
)[0][0, 0]
161 |
162 |
.float() # Convert to float32 for numpy
163 |
164 |
165 |
166 |
return (hps.data.sampling_rate, hyp), text
167 |
168 |
169 |
["I am going to the store.", "eng (English)", 1.0],
170 |
["안녕하세요.", "kor (Korean)", 1.0],
171 |
["क्या मुझे पीने का पानी मिल सकता है?", "hin (Hindi)", 1.0],
172 |
["Tanış olmağıma çox şadam", "azj-script_latin (Azerbaijani, North)", 1.0],
173 |
["Mu zo murna a cikin ƙasar.", "hau (Hausa)", 1.0],
174 |