from fastapi import FastAPI, File, UploadFile, Form, HTTPException
import uvicorn
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
import ffmpeg
import io
import base64
import json
import logging

import numpy as np
import soundfile as sf
import nltk
import librosa

from flores200_codes import flores_codes

nltk.download("punkt")
nltk.download("punkt_tab")

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load Whisper model from Hugging Face
model_name = "openai/whisper-base"
device = 0 if torch.cuda.is_available() else -1  # Use GPU if available
whisper_pipeline = pipeline("automatic-speech-recognition", model=model_name, device=device)


# set up translation pipeline
def get_translation_pipeline(translation_model_path):
    model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_path)
    tokenizer = AutoTokenizer.from_pretrained(translation_model_path)
    translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, device=device)
    return translation_pipeline


translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")


def load_tts_model(model_id):
    model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
    return model_pipeline


def initialize_tts_pipelines(load_models=False):
    """Read the TTS model config and load the preloaded (or all) models."""
    global tts_config_settings
    global tts_pipelines
    with open("tts_models_config.json") as f:
        tts_config_settings = json.loads(f.read())
    for lang, lang_config in tts_config_settings.items():
        if lang in tts_preload_languages or load_models:
            tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])


def ensure_tts_pipeline_loaded(lang_code):
    """Lazily load the TTS model for a language on first use."""
    global tts_config_settings
    global tts_pipelines
    if lang_code not in tts_pipelines:
        lang_config = tts_config_settings[lang_code]
        tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])


def load_asr_model(model_id):
    model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
    return model_pipeline


def initialize_asr_pipelines(load_models=False):
    """Read the ASR model config and load the preloaded (or all) models."""
    global asr_config_settings
    global asr_pipelines
    with open("asr_models_config.json") as f:
        asr_config_settings = json.loads(f.read())
    # iterate through config language entries and load a model for each into a dictionary
    for lang, lang_config in asr_config_settings.items():
        if lang in asr_preload_languages or load_models:
            asr_pipelines[lang] = load_asr_model(lang_config["model_repo"])


def ensure_pipeline_loaded(lang_code):
    """Lazily load the ASR model for a language on first use."""
    global asr_config_settings
    global asr_pipelines
    if lang_code not in asr_pipelines:
        lang_config = asr_config_settings[lang_code]
        asr_pipelines[lang_code] = load_asr_model(lang_config["model_repo"])


# Per-language ASR state used by the helpers above. Nothing in this file calls
# initialize_asr_pipelines() at startup (the /recognize endpoint uses the global
# Whisper pipeline directly), so the preload list is left empty; that default is
# an assumption and can be adjusted as needed.
asr_config_settings = {}
asr_pipelines = {}
asr_preload_languages = []
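# Note: the two JSON config files read above are not part of this file. From the
# way each entry is consumed (lang_config["model_repo"]), they are assumed to map
# a language code to a Hugging Face model repo, roughly like the sketch below;
# the language codes and repo IDs are placeholders, not the actual values used in
# deployment.
#
#   asr_models_config.json / tts_models_config.json (assumed shape):
#   {
#       "kik": {"model_repo": "<org>/<kikuyu-model>"},
#       "luo": {"model_repo": "<org>/<luo-model>"}
#   }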
class RecognitionResponse(BaseModel):
    text: str


# @app.post("/recognize", response_model=RecognitionResponse)
@app.post("/recognize")
async def recognize_audio(audio: UploadFile = File(...), language: str = Form("en")):
    try:
        # Read audio data
        audio_bytes = await audio.read()

        # Convert audio bytes to WAV format if needed
        try:
            input_audio = ffmpeg.input("pipe:0")
            audio_data, _ = (
                input_audio.output("pipe:1", format="wav")
                .run(input=audio_bytes, capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error:
            logger.error("FFmpeg error while converting audio data", exc_info=True)
            raise HTTPException(status_code=400, detail="Invalid audio format")

        # Run Whisper model on the audio
        language = language.strip('"')
        result = whisper_pipeline(
            audio_data,
            chunk_length_s=30,
            generate_kwargs={"language": language, "task": "transcribe"},
            return_timestamps="word",
        )

        # Extract transcription text and word-level timestamps
        transcription = result["text"]
        logger.info(f"Transcription successful: {transcription}")
        segments = result["chunks"]
        logger.info(segments)

        transcription_result = []
        for segment in segments:
            transcription_result.append({
                "word": segment["text"],
                "startTime": round(segment["timestamp"][0], 1),
                "endTime": round(segment["timestamp"][1], 1),
            })
        logger.info(transcription_result)
        return {"text": transcription, "chunks": transcription_result}

    except HTTPException:
        # Re-raise client errors (e.g. invalid audio) instead of masking them as 500s
        raise
    except Exception:
        logger.error("Unexpected error during transcription", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error")


class TranslationRequest(BaseModel):
    text: str
    sourceLanguage: str
    targetLanguage: str


class TranslationResponse(BaseModel):
    translatedText: str


@app.post("/translate", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
    source_language = request.sourceLanguage
    target_language = request.targetLanguage
    text_to_translate = request.text

    try:
        # Map the request language names to the FLORES-200 codes expected by the NLLB model
        src_lang = flores_codes[source_language]
        tgt_lang = flores_codes[target_language]
        translated_text = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)[0]["translation_text"]
        return TranslationResponse(translatedText=translated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class TTSRequest(BaseModel):
    language: str
    text: str


class TTSResponse(BaseModel):
    audioBytes: str


# Per-language TTS state; only the languages listed here are loaded at startup,
# the rest are loaded lazily on first request.
tts_config_settings = {}
tts_pipelines = {}
tts_preload_languages = ["kik"]

# Read the TTS config (and preload the listed models) so the endpoint below can
# resolve model repos when lazily loading other languages.
initialize_tts_pipelines()


@app.post("/text-to-speech", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest):
    """
    Convert the given text to speech and return the audio data in Base64 format.
    """
    text = request.text.strip()
    language = request.language.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text is empty")

    try:
        # Generate speech using the TTS pipeline for the requested language
        logger.info("Generating speech...")
        ensure_tts_pipeline_loaded(language)
        tts_pipeline = tts_pipelines[language]

        # Synthesize sentence by sentence and concatenate the audio into one stream
        sentences = nltk.sent_tokenize(text)
        stream = None
        for sentence_text in sentences:
            audio = tts_pipeline(sentence_text)
            if stream is not None:
                stream = (stream[0], np.concatenate([stream[1], audio["audio"].T]))
            else:
                stream = (audio["sampling_rate"], audio["audio"].T)
        sample_rate, waveform = stream

        # Save the audio to a BytesIO buffer as a WAV file
        buffer = io.BytesIO()
        sf.write(buffer, waveform.squeeze(), sample_rate, format="WAV")
        buffer.seek(0)

        # Encode the audio as Base64
        audio_bytes = base64.b64encode(buffer.read()).decode("utf-8")

        return TTSResponse(audioBytes=audio_bytes)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")


# Run the FastAPI application
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
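# Example requests (a sketch; the port matches the uvicorn.run call above, and the
# language names passed to /translate are assumed to be keys of flores_codes):
#
#   curl -X POST http://localhost:7860/recognize \
#        -F "audio=@sample.wav" -F "language=en"
#
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello", "sourceLanguage": "English", "targetLanguage": "Swahili"}'
#
#   curl -X POST http://localhost:7860/text-to-speech \
#        -H "Content-Type: application/json" \
#        -d '{"language": "kik", "text": "Some text to synthesize."}'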