from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import numpy as np
import io
import soundfile as sf
import logging
import librosa
from pydub import AudioSegment
from moviepy.editor import VideoFileClip
import traceback
from logging.handlers import RotatingFileHandler
import os
import time
import tempfile
import asyncio
import torch
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add a rotating file handler (~10 MB per file, 5 backups)
file_handler = RotatingFileHandler('app.log', maxBytes=10000000, backupCount=5)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# Import functions from other modules
from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE
from tts import synthesize, TTS_LANGUAGES
# NOTE(review): lid.identify is shadowed by the local identify() below; the
# import is kept (aliased) only for any module-level side effects in `lid`.
from lid import identify as _lid_identify  # noqa: F401

app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")

# API key, overridable via environment.
# SECURITY: a hardcoded default key offers essentially no protection --
# set the API_KEY environment variable in any real deployment.
API_KEY = os.environ.get("API_KEY", "africa-best-app")

# Language Identification configuration (MMS LID, 1024 languages)
model_id = "facebook/mms-lid-1024"
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
LID_SAMPLING_RATE = 16_000
LID_THRESHOLD = 0.33  # minimum top-1 probability required to report a prediction

# Load ISO-code -> language-name mappings.
# TODO(review): file is named .tsv but the original code splits on a single
# space -- confirm the actual delimiter used in data/lid/all_langs.tsv.
LID_LANGUAGES = {}
with open("data/lid/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        LID_LANGUAGES[iso] = name.strip()


# Define request models
class AudioRequest(BaseModel):
    # Language selector for the request (full name or ISO code; see endpoints).
    language: str


class TTSRequest(BaseModel):
    # Text to synthesize.
    text: str
    # Language selector; the leading token is treated as the ISO code.
    language: str
    # Playback speed multiplier, validated to [0.5, 2.0] in the endpoint.
    speed: float


# Dependency to verify the API key supplied in the X-API-Key header.
async def verify_api_key(api_key: str = Header(None, alias="X-API-Key")):
    """Raise 401 unless the X-API-Key header matches the configured key."""
    if api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return api_key


def _extract_audio_sync(temp_file_path):
    """Blocking decode of *temp_file_path* into (mono float array, sample_rate).

    Tries, in order: soundfile (plain audio), moviepy (video containers),
    pydub/ffmpeg (everything else). Raises ValueError when no decoder
    can handle the file.
    """
    try:
        # First, try to read as a standard audio file.
        audio_array, sample_rate = sf.read(temp_file_path)
        # soundfile returns stereo as (frames, channels); downmix to mono so
        # downstream resampling/ASR always sees a 1-D signal. (The original
        # code skipped this, which broke librosa.resample on stereo input.)
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        return audio_array, sample_rate
    except Exception:
        pass
    try:
        # Try to read as a video file and extract its audio track.
        video = VideoFileClip(temp_file_path)
        try:
            audio = video.audio
            if audio is None:
                # Caught below; falls through to the pydub attempt,
                # matching the original control flow.
                raise ValueError("Video file contains no audio")
            audio_array = audio.to_soundarray()
            sample_rate = audio.fps
            # Convert to mono if stereo.
            if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
                audio_array = audio_array.mean(axis=1)
            audio_array = audio_array.astype(np.float32)
            peak = np.max(np.abs(audio_array))
            if peak > 0:  # guard: a silent track would divide by zero
                audio_array /= peak
            return audio_array, sample_rate
        finally:
            # Always release the clip (the original leaked it on error paths).
            video.close()
    except Exception:
        pass
    try:
        # Last resort: pydub, which handles many container formats via ffmpeg.
        audio = AudioSegment.from_file(temp_file_path)
        samples = np.array(audio.get_array_of_samples())
        # Normalize integer PCM to float32 in [-1, 1] based on sample width.
        scale = 2 ** 15 if audio.sample_width == 2 else 2 ** 7
        samples = samples.astype(np.float32) / scale
        # Interleaved stereo -> mono.
        if audio.channels == 2:
            samples = samples.reshape((-1, 2)).mean(axis=1)
        return samples, audio.frame_rate
    except Exception as e:
        raise ValueError(f"Unsupported file format: {str(e)}")


async def extract_audio_from_file(input_bytes):
    """Extract a mono audio signal from raw file bytes (audio or video).

    Returns (audio_array, sample_rate). Raises ValueError when the payload
    cannot be decoded. The blocking decode work is offloaded to a worker
    thread so the event loop is not stalled.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_file:
        temp_file.write(input_bytes)
        temp_file_path = temp_file.name
    try:
        return await asyncio.to_thread(_extract_audio_sync, temp_file_path)
    finally:
        # Clean up the temporary file regardless of decode outcome.
        os.unlink(temp_file_path)


def identify(audio_data):
    """Identify the language of the given audio data.

    *audio_data* is either a (sample_rate, int16_samples) tuple (microphone
    style) or a path/file accepted by librosa.load. Returns a dict mapping
    language name -> probability for the top-5 predictions, or a warning
    string when the best score falls below LID_THRESHOLD.
    """
    if isinstance(audio_data, tuple):
        # Microphone input: int16 PCM -> float32 in [-1, 1], then resample.
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(np.float32)
        if sr != LID_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=LID_SAMPLING_RATE
            )
    else:
        # File upload: librosa handles decoding, resampling and mono downmix.
        audio_samples = librosa.load(audio_data, sr=LID_SAMPLING_RATE, mono=True)[0]

    # Process audio into model features.
    inputs = processor(
        audio_samples, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt"
    )

    # Move model and inputs to GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)

    # Perform inference.
    with torch.no_grad():
        logit = model(**inputs).logits

    # Top-5 predictions as probabilities.
    logit_lsm = torch.log_softmax(logit.squeeze(), dim=-1)
    scores, indices = torch.topk(logit_lsm, 5, dim=-1)
    scores, indices = torch.exp(scores).to("cpu").tolist(), indices.to("cpu").tolist()
    iso2score = {model.config.id2label[int(i)]: s for s, i in zip(scores, indices)}

    # Filter results based on confidence threshold.
    if max(iso2score.values()) < LID_THRESHOLD:
        return "Low confidence in the language identification predictions. Output is not shown!"

    return {LID_LANGUAGES[iso]: score for iso, score in iso2score.items()}


@app.post("/transcribe")
async def transcribe_audio(
    language: str,
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Transcribe an uploaded audio/video file in the given language.

    Returns JSON with the transcription and processing time; 500 with
    error details on failure.
    """
    start_time = time.time()
    try:
        input_bytes = await file.read()
        audio_array, sample_rate = await extract_audio_from_file(input_bytes)

        # Ensure audio_array is float32 for resampling/ASR.
        audio_array = audio_array.astype(np.float32)

        # Resample to the ASR model's expected rate if necessary.
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(
                audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE
            )

        # Run blocking ASR inference off the event loop.
        result = await asyncio.to_thread(transcribe, audio_array, language)

        processing_time = time.time() - start_time
        return JSONResponse(content={
            "transcription": result,
            "processing_time_seconds": processing_time
        })
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred during transcription",
                "details": error_details,
                "processing_time_seconds": processing_time
            }
        )


@app.post("/synthesize")
async def synthesize_speech(
    request: TTSRequest,
    api_key: str = Depends(verify_api_key)
):
    """Synthesize speech for the given text/language/speed.

    Streams a WAV file on success; returns 400 for invalid input and 500
    for unexpected failures.
    """
    start_time = time.time()
    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        # Extract the ISO code from the full language name.
        lang_code = request.language.split()[0].strip()

        # Input validation.
        if not request.text:
            raise ValueError("Text cannot be empty")
        if lang_code not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {request.language}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info(f"Calling synthesize function with lang_code: {lang_code}")
        # Run blocking TTS inference off the event loop.
        result, filtered_text = await asyncio.to_thread(
            synthesize, request.text, request.language, request.speed
        )
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")

        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")

        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        logger.info("Converting audio to numpy array")
        audio = np.array(audio, dtype=np.float32)
        logger.info(f"Converted audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Normalizing audio")
        max_value = np.max(np.abs(audio))
        if max_value == 0:
            logger.warning("Audio array is all zeros")
            raise ValueError("Generated audio is silent (all zeros)")
        audio = audio / max_value
        logger.info(f"Normalized audio range: [{audio.min()}, {audio.max()}]")

        logger.info("Converting to int16")
        audio = (audio * 32767).astype(np.int16)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        # Return the audio file directly to the client.
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"}
        )
    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=400,
            content={
                "message": "Invalid input",
                "details": str(ve),
                "processing_time_seconds": processing_time
            }
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={
                "message": "An unexpected error occurred during speech synthesis",
                "details": error_details,
                "processing_time_seconds": processing_time
            }
        )
    finally:
        logger.info("Synthesize request completed")


@app.post("/identify")
async def identify_language(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Identify the language spoken in an uploaded audio file."""
    start_time = time.time()
    try:
        # Read the uploaded file.
        input_bytes = await file.read()

        # Save the file temporarily for librosa to load.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(input_bytes)
            temp_file_path = temp_file.name

        try:
            # Run blocking LID inference off the event loop (consistent with
            # the other endpoints).
            result = await asyncio.to_thread(identify, temp_file_path)
        finally:
            # Clean up the temporary file even when identification fails
            # (the original leaked the file on error).
            os.unlink(temp_file_path)

        processing_time = time.time() - start_time
        return JSONResponse(content={
            "language_identification": result,
            "processing_time_seconds": processing_time
        })
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred during language identification",
                "details": error_details,
                "processing_time_seconds": processing_time
            }
        )


@app.get("/asr_languages")
async def get_asr_languages(api_key: str = Depends(verify_api_key)):
    """Return the mapping of languages supported by the ASR model."""
    start_time = time.time()
    try:
        processing_time = time.time() - start_time
        return JSONResponse(content={
            "languages": ASR_LANGUAGES,
            "processing_time_seconds": processing_time
        })
    except Exception as e:
        logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching ASR languages",
                "details": error_details,
                "processing_time_seconds": processing_time
            }
        )


@app.get("/tts_languages")
async def get_tts_languages(api_key: str = Depends(verify_api_key)):
    """Return the mapping of languages supported by the TTS model."""
    start_time = time.time()
    try:
        processing_time = time.time() - start_time
        return JSONResponse(content={
            "languages": TTS_LANGUAGES,
            "processing_time_seconds": processing_time
        })
    except Exception as e:
        logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching TTS languages",
                "details": error_details,
                "processing_time_seconds": processing_time
            }
        )


@app.get("/lid_languages")
async def get_lid_languages(api_key: str = Depends(verify_api_key)):
    """Return the ISO-code -> name mapping supported by the LID model."""
    start_time = time.time()
    try:
        processing_time = time.time() - start_time
        return JSONResponse(content={
            "languages": LID_LANGUAGES,
            "processing_time_seconds": processing_time
        })
    except Exception as e:
        logger.error(f"Error in get_lid_languages: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching LID languages",
                "details": error_details,
                "processing_time_seconds": processing_time
            }
        )