import os
import logging

# Set the cache directory for Hugging Face before importing transformers,
# since the default cache path is resolved when the library is imported.
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load model and tokenizer
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {e}")
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")
class TextRequest(BaseModel):
    text: str


class BatchTextRequest(BaseModel):
    texts: list[str]
@app.post("/predict/")
async def predict(request: TextRequest):
try:
model.eval()
inputs = tokenizer.encode_plus(
request.text,
add_special_tokens=True,
max_length=64,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt'
)
with torch.no_grad():
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
prediction = torch.argmax(logits, dim=1).item()
return {"prediction": "Spam" if prediction == 1 else "Ham"}
except Exception as e:
logger.error(f"Prediction failed: {e}")
raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
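# Illustrative usage (not part of the original file): the endpoint expects the
# TextRequest schema above, so a request/response exchange would look roughly like:
#   POST /predict/   body: {"text": "Win a free prize now!"}
#   response: {"prediction": "Spam"}   (or {"prediction": "Ham"})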
@app.post("/batch_predict/")
async def batch_predict(request: BatchTextRequest):
try:
model.eval()
# Batch encode all texts in the request at once
inputs = tokenizer(
request.texts,
add_special_tokens=True,
max_length=64,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt'
)
# Run batch inference
with torch.no_grad():
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
predictions = torch.argmax(logits, dim=1).tolist()
# Format results
results = [
{"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
for idx, (text, pred) in enumerate(zip(request.texts, predictions))
]
logger.info(f"Batch prediction results: {results}")
return {"results": results}
except Exception as e:
logging.error(f"Batch prediction failed: {e}")
raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
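# Illustrative usage (not part of the original file): the endpoint expects the
# BatchTextRequest schema above and returns one result per input text, e.g.:
#   POST /batch_predict/   body: {"texts": ["Win a free prize now!", "See you at lunch"]}
#   response: {"results": [{"id": 1, "text": "Win a free prize now!", "prediction": "Spam"},
#                          {"id": 2, "text": "See you at lunch", "prediction": "Ham"}]}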
@app.get("/")
async def root():
    return {"message": "Welcome to the MobileBERT API"}