from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import os
import re
import logging

# Set the Hugging Face cache directory before importing transformers,
# since the library resolves its cache paths at import time.
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load model and tokenizer
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {e}")
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")


class TextRequest(BaseModel):
    text: str


class BatchTextRequest(BaseModel):
    texts: list[str]


# Regular expression to detect Bangla characters
bangla_regex = re.compile('[\u0980-\u09FF]')


def contains_bangla(text):
    return bool(bangla_regex.search(text))


@app.post("/batch_predict/")
async def batch_predict(request: BatchTextRequest):
    try:
        model.eval()

        # Prepare the batch results
        results = []
        for idx, text in enumerate(request.texts):
            # Texts without any Bangla characters are not classified
            if not contains_bangla(text):
                results.append({"id": idx + 1, "text": text, "prediction": "other"})
                continue

            # Encode and predict for texts containing Bangla characters
            inputs = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
            prediction = torch.argmax(logits, dim=1).item()
            label = "Spam" if prediction == 1 else "Ham"
            results.append({"id": idx + 1, "text": text, "prediction": label})

        logger.info(f"Batch prediction results: {results}")
        return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}")
        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")


@app.get("/")
async def root():
    return {"message": "Welcome to the MobileBERT API"}
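
# ---------------------------------------------------------------------------
# Optional local entry point and usage sketch. Assumptions (not part of the
# original service): the file is saved as app.py, uvicorn is installed
# alongside fastapi, and the server is reachable on localhost:8000. The
# sample texts below are illustrative only.
#
# Example client call once the server is running:
#
#     import requests
#
#     payload = {"texts": ["আপনার অ্যাকাউন্ট যাচাই করুন", "hello world"]}
#     resp = requests.post("http://localhost:8000/batch_predict/", json=payload)
#     print(resp.json())
#
# Per the endpoint logic above, non-Bangla entries come back labelled
# "other", while Bangla entries are labelled "Spam" or "Ham" by the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # Serve the app directly when the module is executed as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)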