Bijoy09's picture
Update app.py
3adf7f1 verified
raw
history blame
2.14 kB
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import os
import logging
app = FastAPI()

# Root logger at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default the Hugging Face cache location to /app/cache unless the environment
# already provides TRANSFORMERS_CACHE (an existing value is left untouched).
# NOTE(review): `transformers` is imported above this line; if the library reads
# TRANSFORMERS_CACHE at import time, this assignment comes too late to affect it.
os.environ.setdefault('TRANSFORMERS_CACHE', '/app/cache')

# Allow cross-origin calls from any origin, with any method and header.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# if cookies/credentials are ever used, restrict allow_origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model and tokenizer once, at import time, so every request reuses them.
model_name = "Bijoy09/MObilebert"

# Pass the cache directory explicitly: transformers resolves its default cache
# location when it is first imported, so exporting TRANSFORMERS_CACHE after
# `import transformers` is not guaranteed to take effect.
_cache_dir = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=_cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=_cache_dir)
except Exception as e:
    # Log with traceback, then abort start-up: the API cannot serve without the model.
    logger.exception("Failed to load model or tokenizer: %s", e)
    raise RuntimeError(f"Failed to load model or tokenizer: {e}") from e

# Inference-only service: disable training-mode behaviour (e.g. dropout) once here.
model.eval()
logger.info("Model and tokenizer loaded successfully")
class TextRequest(BaseModel):
    """JSON request body for the prediction endpoint."""

    # The message to classify.
    text: str
@app.post("/predict/")
@app.post("/predict")
def predict(request: TextRequest):
    """Classify ``request.text`` and return ``{"prediction": "Spam" | "Ham"}``.

    Declared as a plain ``def`` on purpose: tokenization and the forward pass
    are blocking, CPU-bound calls. FastAPI runs sync path operations in its
    worker threadpool, so a slow prediction no longer blocks the event loop
    (and with it every other in-flight request) as it did under ``async def``.

    Args:
        request: Body with a single ``text`` field.

    Returns:
        ``{"prediction": "Spam"}`` when the arg-max class index is 1,
        otherwise ``{"prediction": "Ham"}``.

    Raises:
        HTTPException: status 500 if tokenization or inference fails.
    """
    try:
        # Message content may be private: only its length goes to INFO logs,
        # the text itself and derived tensors are DEBUG-only.
        logger.info("Received text (%d chars)", len(request.text))
        logger.debug("Received text: %s", request.text)
        # Special tokens added; truncated/padded to exactly 64 tokens.
        # (Tokenizer __call__ replaces the deprecated encode_plus; same arguments.)
        inputs = tokenizer(
            request.text,
            add_special_tokens=True,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        logger.debug("Tokenized inputs: %s", inputs)
        # No gradients needed for inference; the model is already in eval mode
        # (from_pretrained returns it that way), so no per-request eval() call.
        with torch.no_grad():
            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
        logger.debug("Model logits: %s", logits)
        prediction = torch.argmax(logits, dim=1).item()
    except Exception as e:
        logger.exception("Prediction failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
    return {"prediction": "Spam" if prediction == 1 else "Ham"}
@app.get("/")
async def root():
    """Return a static welcome message at the service root."""
    payload = {"message": "Welcome to the MobileBERT API"}
    return payload