from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import ffmpeg
import io
import logging
from flores200_codes import flores_codes
import nltk
import json
import soundfile as sf
import numpy as np
import base64
from PIL import Image
from io import BytesIO
import spaces
import easyocr
import cv2
import os

# Ensure EasyOCR uses a directory with write permissions
# os.environ["EASYOCR_CACHE_DIR"] = "/app/.EasyOCR"

nltk.download("punkt")
nltk.download("punkt_tab")

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the EasyOCR reader once, globally, for better performance
reader = None


@app.on_event("startup")
async def startup_event():
    """Initialize EasyOCR during startup."""
    global reader
    try:
        logger.info("Checking GPU availability...")
        if torch.cuda.is_available():
            device_name = "cuda"
            gpu = True
            logger.info(f"GPU detected: {torch.cuda.get_device_name(0)}")
        else:
            device_name = "cpu"
            gpu = False
            logger.warning("No GPU detected, falling back to CPU")

        logger.info("Initializing EasyOCR and downloading models...")
        # Set a download directory so we know where models are stored
        model_storage_directory = os.path.join(os.getcwd(), "models")
        logger.info("Creating models folder")
        os.makedirs(model_storage_directory, exist_ok=True)

        # Download and initialize the detection/recognition models
        reader = easyocr.Reader(
            ['en'],
            # model_storage_directory=model_storage_directory,
            # download_enabled=True,  # Force download even if the model exists
            gpu=gpu,  # Enable GPU if available
        )
        logger.info("Initialized the reader. Testing operation using a sample image")

        # Perform a small inference to ensure everything is loaded
        sample_image = np.zeros((100, 100), dtype=np.uint8)
        reader.readtext(sample_image, detail=0)

        logger.info(f"EasyOCR initialization completed successfully using {device_name.upper()}")
    except Exception as e:
        logger.error(f"Failed to initialize EasyOCR: {str(e)}")
        raise


# Default Whisper model from Hugging Face
model_name = "openai/whisper-base"
device = 0 if torch.cuda.is_available() else -1  # Use GPU if available


# Set up the translation pipeline
def get_translation_pipeline(translation_model_path):
    model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_path)
    tokenizer = AutoTokenizer.from_pretrained(translation_model_path)
    return pipeline('translation', model=model, tokenizer=tokenizer, device=device)


translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")

asr_config_settings = {}
asr_pipelines = {}
asr_preload_languages = ["eng"]


def load_asr_model(model_id):
    return pipeline("automatic-speech-recognition", model=model_id, device=device)


def initialize_asr_pipelines(load_models=False):
    global asr_config_settings
    # Expected config shape (assumption): {"eng": {"model_repo": "<hf-repo-id>"}, ...}
    with open("asr_models_config.json") as f:
        asr_config_settings = json.loads(f.read())
    # Iterate through the config's language entries and load a model for each
    for lang, lang_config in asr_config_settings.items():
        if lang in asr_preload_languages or load_models:
            asr_pipelines[lang] = load_asr_model(lang_config["model_repo"])


def ensure_asr_pipeline_loaded(lang_code):
    # Lazily load the ASR model for this language on first use
    if lang_code not in asr_pipelines:
        lang_config = asr_config_settings[lang_code]
        asr_pipelines[lang_code] = load_asr_model(lang_config["model_repo"])


class RecognitionResponse(BaseModel):
    text: str


@spaces.GPU
@app.post("/recognize")
async def recognize_audio(audio: UploadFile = File(...), language: str = Form("en")):
    try:
        # Read audio data
        audio_bytes = await audio.read()

        # Convert audio bytes to WAV format if needed
        try:
            input_audio = ffmpeg.input('pipe:0')
            audio_data, _ = (
                input_audio.output('pipe:1', format='wav')
                .run(input=audio_bytes, capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error:
            logger.error("FFmpeg error while converting audio data", exc_info=True)
            raise HTTPException(status_code=400, detail="Invalid audio format")

        # Run the ASR model for the requested language on the audio
        language = language.strip('\"')
        ensure_asr_pipeline_loaded(language)
        transcriber = asr_pipelines[language]
        result = transcriber(audio_data, return_timestamps="word")

        # Extract the transcription text and word-level timestamps
        transcription = result["text"].capitalize()
        segments = result["chunks"]

        transcription_result = []
        for segment in segments:
            transcription_result.append({
                "word": segment['text'],
                "startTime": round(segment['timestamp'][0], 1),
                "endTime": round(segment['timestamp'][1], 1),
            })

        return {"text": transcription, "chunks": transcription_result}
    except HTTPException:
        # Re-raise client errors instead of masking them as 500s
        raise
    except Exception:
        logger.error("Unexpected error during transcription", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error")


class TranslationRequest(BaseModel):
    text: str
    sourceLanguage: str
    targetLanguage: str


class TranslationResponse(BaseModel):
    translatedText: str
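# Illustrative /translate request payload (a sketch, not from the original
# source): sourceLanguage and targetLanguage must be keys of the flores_codes
# mapping imported above, which resolves human-readable names to FLORES-200
# codes such as "eng_Latn" or "swh_Latn".
#
#   {"text": "Good morning", "sourceLanguage": "English", "targetLanguage": "Swahili"}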
@spaces.GPU
@app.post("/translate", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
    source_language = request.sourceLanguage
    target_language = request.targetLanguage
    text_to_translate = request.text

    try:
        # Map human-readable language names to FLORES-200 codes
        src_lang = flores_codes[source_language]
        tgt_lang = flores_codes[target_language]
        translated_text = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)[0]['translation_text']
        return TranslationResponse(translatedText=translated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class TTSRequest(BaseModel):
    language: str
    text: str


class TTSResponse(BaseModel):
    audioBytes: str
    sampleRate: int


tts_config_settings = {}
tts_pipelines = {}
tts_preload_languages = ["kik"]


def load_tts_model(model_id):
    return pipeline("text-to-speech", model=model_id, device=device)


def initialize_tts_pipelines(load_models=False):
    global tts_config_settings
    with open("tts_models_config.json") as f:
        tts_config_settings = json.loads(f.read())
    for lang, lang_config in tts_config_settings.items():
        if lang in tts_preload_languages or load_models:
            tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])


def ensure_tts_pipeline_loaded(lang_code):
    # Lazily load the TTS model for this language on first use
    if lang_code not in tts_pipelines:
        lang_config = tts_config_settings[lang_code]
        tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])


@spaces.GPU
@app.post("/text-to-speech", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest):
    """Convert the given text to speech and return the audio data in Base64 format."""
    logger.debug(f"TTS request: {request}")
    text = request.text.strip()
    language = request.language.strip()

    if not text:
        raise HTTPException(status_code=400, detail="Input text is empty")

    try:
        # Generate speech using the TTS pipeline for the requested language
        logger.info("Generating speech...")
        ensure_tts_pipeline_loaded(language)
        tts_pipeline = tts_pipelines[language]

        result = tts_pipeline(text)
        audio_tensor = result["audio"]
        sample_rate = result.get("sampling_rate", 22050)

        # Transpose to (samples, channels), the layout soundfile expects
        audio_data = audio_tensor.T

        # Save the audio to a BytesIO buffer as a 16-bit PCM WAV file
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16")
        buffer.seek(0)

        # Encode the audio as Base64
        audio_bytes = base64.b64encode(buffer.read()).decode("utf-8")

        return TTSResponse(audioBytes=audio_bytes, sampleRate=sample_rate)
    except Exception as e:
        logger.error("Unexpected error during TTS", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")


# TrOCR model used by the /ocr2 endpoint (this block was disabled in a string
# literal while /ocr2 still referenced it; it is re-enabled here so the
# endpoint can run)
ocr_model_name = "microsoft/trocr-large-printed"
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)

# Ensure we're using the appropriate device (GPU if available)
ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_model.to(ocr_device)


class OcrRequest(BaseModel):
    imageBase64: str  # Base64-encoded image


class OcrResponse(BaseModel):
    text: str


def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string to a PIL Image."""
    try:
        image_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception:
        raise ValueError("Invalid image data")
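# OcrRequest and base64_to_image() are not used by the endpoints below, which
# take multipart uploads instead. The following JSON-based variant is a minimal
# sketch (an assumption, not part of the original API) showing how they could
# be wired to the EasyOCR reader.
@app.post("/ocr-base64", response_model=OcrResponse)
async def process_ocr_base64(request: OcrRequest):
    if reader is None:
        raise HTTPException(status_code=500, detail="OCR system not initialized")
    try:
        pil_image = base64_to_image(request.imageBase64)
        # PIL yields RGB; convert to BGR since EasyOCR assumes OpenCV ordering
        img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        # detail=0 returns a plain list of recognized strings
        results = reader.readtext(img, detail=0)
        return OcrResponse(text=" ".join(results).strip())
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))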
@spaces.GPU
@app.post("/ocr2", response_model=OcrResponse)
async def process_ocr2(image: UploadFile = File(...)):
    try:
        # Read the uploaded file
        image_bytes = await image.read()

        # Convert bytes to a PIL Image
        pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")

        pixel_values = ocr_processor(images=pil_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(ocr_device)

        # Perform OCR using the TrOCR model
        generated_ids = ocr_model.generate(pixel_values)
        text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")


@spaces.GPU
@app.post("/ocr", response_model=OcrResponse)
async def process_ocr(image: UploadFile = File(...)):
    if reader is None:
        raise HTTPException(status_code=500, detail="OCR system not initialized")

    try:
        # Validate the uploaded file
        if not image.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read the image file and decode it with OpenCV
        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if img is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        # Perform OCR; detail=0 returns a plain list of recognized strings,
        # which we join into a single text
        results = reader.readtext(img, detail=0)
        text = " ".join(results)

        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except HTTPException:
        # Re-raise client errors instead of masking them as 500s
        raise
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")


# Optional: a health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


# Preload the TTS and ASR pipelines at import time
initialize_tts_pipelines(True)
initialize_asr_pipelines()

# Run the FastAPI application
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
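# Example requests (illustrative sketches; the file names are placeholders and
# the language codes depend on flores_codes, asr_models_config.json and
# tts_models_config.json in your deployment):
#
#   curl -X POST http://localhost:7860/recognize \
#        -F "audio=@sample.wav" -F "language=eng"
#
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Good morning", "sourceLanguage": "English", "targetLanguage": "Swahili"}'
#
#   curl -X POST http://localhost:7860/text-to-speech \
#        -H "Content-Type: application/json" \
#        -d '{"language": "kik", "text": "..."}'
#
#   curl -X POST http://localhost:7860/ocr -F "image=@receipt.png"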