from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import logging
import os

app = FastAPI()

# Set the cache directory for Hugging Face (defaults to /app/cache inside the container)
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

# Load model and tokenizer once at startup
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")


class TextRequest(BaseModel):
    text: str


class BatchTextRequest(BaseModel):
    texts: list[str]


@app.post("/predict")
async def predict(request: TextRequest):
    try:
        model.eval()
        inputs = tokenizer.encode_plus(
            request.text,
            add_special_tokens=True,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
        prediction = torch.argmax(logits, dim=1).item()
        return {"prediction": "Spam" if prediction == 1 else "Ham"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")


@app.post("/batch_predict")
async def batch_predict(request: BatchTextRequest):
    try:
        model.eval()
        results = []
        # Encode and classify each text individually
        for idx, text in enumerate(request.texts):
            inputs = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
            prediction = torch.argmax(logits, dim=1).item()
            results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")


# Alternative /batch_predict that encodes the whole request in a single tokenizer
# call and runs one forward pass (kept commented out):
# @app.post("/batch_predict")
# async def batch_predict(request: BatchTextRequest):
#     try:
#         model.eval()
#         # Batch encode all texts in the request at once
#         inputs = tokenizer(
#             request.texts,
#             add_special_tokens=True,
#             max_length=64,
#             truncation=True,
#             padding='max_length',
#             return_attention_mask=True,
#             return_tensors='pt'
#         )
#         # Run batch inference
#         with torch.no_grad():
#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
#         predictions = torch.argmax(logits, dim=1).tolist()
#         # Format results
#         results = [
#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
#         ]
#         return {"results": results}
#     except Exception as e:
#         logging.error(f"Batch prediction failed: {e}")
#         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")


@app.get("/")
async def root():
    return {"message": "Welcome to the MobileBERT API"}
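

# A minimal usage sketch (assumptions: this file is saved as app.py and the
# server is reachable on localhost:8000; the filename, host, and port are
# examples, not part of the app itself):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Congratulations, you have won a free prize!"}'
#
#   curl -X POST http://localhost:8000/batch_predict \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["Win cash now", "See you at lunch tomorrow"]}'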