from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import numpy as np
import io
import soundfile as sf
import base64
import logging
import torch
import librosa
from pathlib import Path
from pydub import AudioSegment
from moviepy.editor import VideoFileClip
import traceback
from logging.handlers import RotatingFileHandler
import os
import boto3
from botocore.exceptions import NoCredentialsError
import time
import tempfile

# Import functions from other modules
from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE
from tts import synthesize, TTS_LANGUAGES
from lid import identify

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add a rotating file handler
file_handler = RotatingFileHandler('app.log', maxBytes=10000000, backupCount=5)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")

# S3 configuration
S3_BUCKET = os.environ.get("S3_BUCKET")
S3_REGION = os.environ.get("S3_REGION")
S3_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Initialize the S3 client
s3_client = boto3.client(
    's3',
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
    region_name=S3_REGION,
)
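
# Note: if these environment variables are unset, the values above are None and
# boto3 falls back to its default credential chain (shared config files, instance
# roles, etc.), so local runs without explicit keys may still resolve credentials.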

# Define request models
class AudioRequest(BaseModel):
    audio: str  # Base64 encoded audio or video data
    language: str


class TTSRequest(BaseModel):
    text: str
    language: str
    speed: float
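
# Illustrative client-side construction of an AudioRequest payload (a sketch; the
# file path and language string are placeholders, not values required by this app):
#
#   import base64
#   with open("sample.wav", "rb") as f:
#       payload = {
#           "audio": base64.b64encode(f.read()).decode("utf-8"),
#           "language": "eng",
#       }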

def extract_audio_from_file(input_bytes):
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_file:
        temp_file.write(input_bytes)
        temp_file_path = temp_file.name

    try:
        # First, try to read as a standard audio file
        audio_array, sample_rate = sf.read(temp_file_path)
        # Convert to mono if the file is multi-channel, matching the other code paths
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        return audio_array, sample_rate
    except Exception:
        try:
            # Try to read as a video file and extract its audio track
            video = VideoFileClip(temp_file_path)
            audio = video.audio
            if audio is not None:
                audio_array = audio.to_soundarray()
                sample_rate = audio.fps
                # Convert to mono if stereo
                if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
                    audio_array = audio_array.mean(axis=1)
                # Ensure audio is float32 and normalized
                audio_array = audio_array.astype(np.float32)
                audio_array /= np.max(np.abs(audio_array))
                video.close()
                return audio_array, sample_rate
            else:
                raise ValueError("Video file contains no audio")
        except Exception:
            # If video reading fails, try as generic audio with pydub
            try:
                audio = AudioSegment.from_file(temp_file_path)
                audio_array = np.array(audio.get_array_of_samples())
                # Convert to float32 and normalize by full scale for the sample width
                # (handles 8-, 16- and 32-bit samples)
                audio_array = audio_array.astype(np.float32) / float(2 ** (8 * audio.sample_width - 1))
                # Convert stereo to mono if necessary
                if audio.channels == 2:
                    audio_array = audio_array.reshape((-1, 2)).mean(axis=1)
                return audio_array, audio.frame_rate
            except Exception as e:
                raise ValueError(f"Unsupported file format: {str(e)}")
    finally:
        # Clean up the temporary file
        os.unlink(temp_file_path)
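
# Quick local check of the extraction fallback chain (a sketch; "sample.mp4" is a
# placeholder path, not something shipped with this app):
#
#   with open("sample.mp4", "rb") as f:
#       array, sr = extract_audio_from_file(f.read())
#   print(array.shape, sr)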

async def transcribe_audio(request: AudioRequest):
    start_time = time.time()
    try:
        input_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = extract_audio_from_file(input_bytes)

        # Ensure audio_array is float32
        audio_array = audio_array.astype(np.float32)

        # Resample if necessary
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)

        result = transcribe(audio_array, request.language)
        processing_time = time.time() - start_time
        return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during transcription", "details": error_details, "processing_time_seconds": processing_time}
        )

async def synthesize_speech(request: TTSRequest):
    start_time = time.time()
    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        # Extract the ISO code from the full language name
        lang_code = request.language.split()[0].strip()

        # Input validation
        if not request.text:
            raise ValueError("Text cannot be empty")
        if lang_code not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {request.language}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info(f"Calling synthesize function with lang_code: {lang_code}")
        result, filtered_text = synthesize(request.text, request.language, request.speed)
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")

        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")

        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        logger.info("Converting audio to numpy array")
        audio = np.array(audio, dtype=np.float32)
        logger.info(f"Converted audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Normalizing audio")
        max_value = np.max(np.abs(audio))
        if max_value == 0:
            logger.warning("Audio array is all zeros")
            raise ValueError("Generated audio is silent (all zeros)")
        audio = audio / max_value
        logger.info(f"Normalized audio range: [{audio.min()}, {audio.max()}]")

        logger.info("Converting to int16")
        audio = (audio * 32767).astype(np.int16)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        # Generate a unique filename
        filename = f"synthesized_audio_{int(time.time())}.wav"

        # Upload to S3 without ACL
        try:
            s3_client.upload_fileobj(
                buffer,
                S3_BUCKET,
                filename,
                ExtraArgs={'ContentType': 'audio/wav'}
            )
            logger.info(f"File uploaded successfully to S3: {filename}")

            # Generate the public URL with the correct format
            url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
            logger.info(f"Public URL generated: {url}")

            processing_time = time.time() - start_time
            return JSONResponse(content={"audio_url": url, "processing_time_seconds": processing_time})
        except NoCredentialsError:
            logger.error("AWS credentials not available or invalid")
            raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials")
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}")
    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
        )
    finally:
        logger.info("Synthesize request completed")

async def identify_language(request: AudioRequest):
    start_time = time.time()
    try:
        input_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = extract_audio_from_file(input_bytes)
        result = identify(audio_array)
        processing_time = time.time() - start_time
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
        )

async def get_asr_languages():
    start_time = time.time()
    try:
        processing_time = time.time() - start_time
        return JSONResponse(content={"languages": ASR_LANGUAGES, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred while fetching ASR languages", "details": error_details, "processing_time_seconds": processing_time}
        )


async def get_tts_languages():
    start_time = time.time()
    try:
        processing_time = time.time() - start_time
        return JSONResponse(content={"languages": TTS_LANGUAGES, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred while fetching TTS languages", "details": error_details, "processing_time_seconds": processing_time}
        )
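
# The handlers above are plain async functions; this file does not attach them to
# routes on `app`. A minimal wiring sketch, assuming these endpoint paths and the
# usual Hugging Face Spaces port (both are assumptions, not defined in this file):
#
#   app.add_api_route("/transcribe", transcribe_audio, methods=["POST"])
#   app.add_api_route("/synthesize", synthesize_speech, methods=["POST"])
#   app.add_api_route("/identify", identify_language, methods=["POST"])
#   app.add_api_route("/asr_languages", get_asr_languages, methods=["GET"])
#   app.add_api_route("/tts_languages", get_tts_languages, methods=["GET"])
#
#   if __name__ == "__main__":
#       import uvicorn
#       uvicorn.run(app, host="0.0.0.0", port=7860)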