import os

# Set the Hugging Face cache directory before transformers is imported;
# the library resolves its cache path at import time, so setting the
# variable afterwards has no effect.
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

import logging

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and tokenizer once at startup
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()  # inference only; set once here instead of on every request
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {e}")
    raise RuntimeError(f"Failed to load model or tokenizer: {e}") from e


class TextRequest(BaseModel):
    text: str


# Register the route both with and without a trailing slash so clients
# hit the handler directly instead of going through a 307 redirect.
@app.post("/predict/")
@app.post("/predict")
async def predict(request: TextRequest):
    try:
        logger.info(f"Received text: {request.text}")
        # Tokenize via the tokenizer's __call__, which supersedes the
        # deprecated encode_plus and returns the same tensors.
        inputs = tokenizer(
            request.text,
            add_special_tokens=True,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        logger.info(f"Tokenized inputs: {inputs}")
        with torch.no_grad():
            logits = model(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            ).logits
        logger.info(f"Model logits: {logits}")
        # Class index 1 is spam for this checkpoint, 0 is ham.
        prediction = torch.argmax(logits, dim=1).item()
        return {"prediction": "Spam" if prediction == 1 else "Ham"}
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")


@app.get("/")
async def root():
    return {"message": "Welcome to the MobileBERT API"}
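
# A minimal sketch of how to run and exercise this service locally; this is
# not part of the original service definition. The module name "main" and
# port 8000 are assumptions; adjust them to match the actual file name and
# deployment.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Congratulations! You have won a free prize."}'
#
# Expected response shape: {"prediction": "Spam"} or {"prediction": "Ham"}.

if __name__ == "__main__":
    # Convenience entry point for local development (assumes uvicorn is
    # installed alongside fastapi).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)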