# Bangla spam-detection API serving a fine-tuned MobileBERT model.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import os
import re
import logging
# FastAPI application instance; route handlers are registered on it below.
app = FastAPI()
# Configure logging
# Root logger at INFO so model-load and prediction events are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set the cache directory for Hugging Face
# NOTE: must run before any from_pretrained() call so downloads land in a
# writable location (defaults to /app/cache, e.g. inside a container).
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
# Load model and tokenizer
# Hugging Face Hub id of the fine-tuned MobileBERT sequence classifier.
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {e}")
    # Fail fast at import time: the API cannot serve without the model.
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")
class TextRequest(BaseModel):
    """Request body for single-text prediction: one text to classify."""
    text: str  # the raw input text
class BatchTextRequest(BaseModel):
    """Request body for /batch_predict/: texts are classified in order."""
    texts: list[str]  # result ids are the 1-based positions in this list
# Pre-compiled pattern matching any code point in the Bangla Unicode block
# (U+0980 through U+09FF).
bangla_regex = re.compile(r'[\u0980-\u09FF]')
def contains_bangla(text):
    """Return True if *text* contains at least one Bangla character."""
    return bangla_regex.search(text) is not None
@app.post("/batch_predict/")
async def batch_predict(request: BatchTextRequest):
    """Classify a batch of texts as "Spam"/"Ham" (Bangla) or "other".

    Texts containing no Bangla characters are labelled "other" without
    running the model. All Bangla texts are tokenized together and classified
    in a single forward pass (logit argmax 1 -> "Spam", otherwise "Ham").

    Returns:
        JSONResponse {"results": [{"id", "text", "prediction"}, ...]} where
        ids are 1-based positions in the request order.

    Raises:
        HTTPException: 500 if tokenization or inference fails.
    """
    try:
        model.eval()
        # Pre-fill so results keep the request order regardless of which
        # texts go through the model.
        results = [None] * len(request.texts)
        bangla_indices = []
        bangla_texts = []
        for idx, text in enumerate(request.texts):
            if contains_bangla(text):
                bangla_indices.append(idx)
                bangla_texts.append(text)
            else:
                # Non-Bangla input: skip the model entirely.
                results[idx] = {"id": idx + 1, "text": text, "prediction": "other"}
        if bangla_texts:
            # One batched tokenizer call + one forward pass instead of a
            # per-text encode/predict loop.
            encoded = tokenizer(
                bangla_texts,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                logits = model(encoded['input_ids'], attention_mask=encoded['attention_mask']).logits
            predictions = torch.argmax(logits, dim=1).tolist()
            for idx, prediction in zip(bangla_indices, predictions):
                label = "Spam" if prediction == 1 else "Ham"
                results[idx] = {"id": idx + 1, "text": request.texts[idx], "prediction": label}
        # Lazy %-formatting: the (potentially large) payload is only rendered
        # when INFO logging is actually enabled.
        logger.info("Batch prediction results: %s", results)
        return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")
    except Exception as e:
        # logger.exception records the traceback for debugging.
        logger.exception("Batch prediction failed")
        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.") from e
@app.get("/")
async def root():
    """Landing endpoint: returns a static welcome message."""
    welcome = {"message": "Welcome to the MobileBERT API"}
    return welcome