import os

# Set the Hugging Face cache directory before transformers is imported, so the
# override actually takes effect and model downloads land in a writable location
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

# Load model and tokenizer once at startup
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")
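
# Optional sketch (an assumption, not part of the original file): move the model
# to a GPU when one is available; the endpoints below would then also need to
# send their input tensors to the same device.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)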

class TextRequest(BaseModel):
    text: str

class BatchTextRequest(BaseModel):
    texts: list[str]
@app.post("/predict")
async def predict(request: TextRequest):
try:
model.eval()
inputs = tokenizer.encode_plus(
request.text,
add_special_tokens=True,
max_length=64,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt'
)
with torch.no_grad():
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
prediction = torch.argmax(logits, dim=1).item()
return {"prediction": "Spam" if prediction == 1 else "Ham"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
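
# Example call for /predict (host and port are assumptions about the deployment):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Congratulations, you have won a prize!"}'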
@app.post("/batch_predict")
async def batch_predict(request: BatchTextRequest):
try:
model.eval()
results = []
for idx, text in enumerate(request.texts):
inputs = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=64,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt'
)
with torch.no_grad():
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
prediction = torch.argmax(logits, dim=1).item()
results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
return {"results": results}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
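
# Example call for /batch_predict (same assumed host/port); note the plural
# "texts" field carrying a JSON list:
#   curl -X POST http://localhost:8000/batch_predict \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["Win a free prize now!", "See you at lunch tomorrow"]}'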

# Alternative implementation: tokenize the whole batch at once and run a single
# forward pass. If enabled, it also needs `import logging` at the top of the file.
# @app.post("/batch_predict")
# async def batch_predict(request: BatchTextRequest):
#     try:
#         model.eval()
#         # Batch encode all texts in the request at once
#         inputs = tokenizer(
#             request.texts,
#             add_special_tokens=True,
#             max_length=64,
#             truncation=True,
#             padding='max_length',
#             return_attention_mask=True,
#             return_tensors='pt'
#         )
#         # Run batch inference
#         with torch.no_grad():
#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
#         predictions = torch.argmax(logits, dim=1).tolist()
#         # Format results
#         results = [
#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
#         ]
#         return {"results": results}
#     except Exception as e:
#         logging.error(f"Batch prediction failed: {e}")
#         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")

@app.get("/")
async def root():
    return {"message": "Welcome to the MobileBERT API"}