File size: 3,019 Bytes
3f6c2a7
 
d1b5757
3f6c2a7
 
cd6d890
0d35220
473261b
1e28e60
3f6c2a7
1e28e60
473261b
 
1e28e60
 
cd6d890
 
1e28e60
3f6c2a7
384e68d
cd6d890
 
 
473261b
cd6d890
473261b
cd6d890
1e28e60
3f6c2a7
 
1e28e60
6632a1b
 
1e28e60
0d35220
 
1e28e60
0d35220
 
1e28e60
 
 
 
d2281db
6632a1b
 
 
1e28e60
0d35220
 
1e28e60
0d35220
 
1e28e60
0d35220
 
 
 
1e28e60
 
 
 
 
0d35220
 
1e28e60
0d35220
 
 
 
 
 
 
1e28e60
0d35220
 
 
 
 
1e28e60
eb64215
1b1a84c
1e28e60
efcb399
0d35220
efcb399
1e28e60
cd6d890
 
0d35220
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import os
import re
import logging

app = FastAPI()

# Configure logging for the service.
logging.basicConfig(level=logging.INFO)
# Fix: the original passed `_name_` (single underscores), an undefined name
# that raises NameError at import time; the module-name dunder is `__name__`.
logger = logging.getLogger(__name__)

# Set the cache directory for Hugging Face downloads; defaults to /app/cache
# (a container path) when TRANSFORMERS_CACHE is not already set.
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

# Load model and tokenizer once at import/startup time.
# Hugging Face Hub identifier of the Bangla spam-classification checkpoint.
model_name = "Bijoy09/MObilebert"
try:
    # Downloads (or reads from TRANSFORMERS_CACHE) the classifier weights
    # and matching tokenizer; both are required before serving requests.
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    # Fail fast: the API cannot serve predictions without the model, so
    # log the cause and abort application startup.
    logger.error(f"Failed to load model or tokenizer: {e}")
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")

class TextRequest(BaseModel):
    """Request payload carrying a single input text.

    NOTE(review): not referenced by any endpoint visible in this file —
    presumably used by a single-text predict route defined elsewhere; confirm.
    """
    text: str

class BatchTextRequest(BaseModel):
    """Request payload for /batch_predict/: a list of raw input texts."""
    texts: list[str]

# Matches any single character in the Bangla (Bengali) Unicode block.
bangla_regex = re.compile('[\u0980-\u09FF]')

def contains_bangla(text):
    """Return True if *text* contains at least one Bangla character."""
    return bangla_regex.search(text) is not None

def remove_non_bangla(text):
    """Return *text* reduced to only its Bangla characters, order preserved."""
    bangla_chars = bangla_regex.findall(text)
    return ''.join(bangla_chars)

@app.post("/batch_predict/")
async def batch_predict(request: BatchTextRequest):
    """Classify a batch of texts as "Spam"/"Ham" (Bangla) or "other".

    For each input text:
      * if it contains no Bangla characters, it is labelled "other" and
        the model is skipped;
      * otherwise the text is stripped down to its Bangla characters and
        run through the sequence classifier — logit argmax 1 maps to
        "Spam", anything else to "Ham".

    Returns a UTF-8 JSON body ``{"results": [{"id", "text", "prediction"}, ...]}``
    where ``id`` is the 1-based position of the text in the request list.

    Raises:
        HTTPException: 500 on any unexpected failure during prediction.
    """
    try:
        # Ensure inference mode (disables dropout etc.).
        model.eval()

        # Per-text results, in request order.
        results = []

        for idx, text in enumerate(request.texts):
            logger.info(f" texts: {text}")

            # Non-Bangla inputs are labelled "other" without invoking the model.
            if not contains_bangla(text):
                results.append({"id": idx + 1, "text": text, "prediction": "other"})
                continue

            # Strip non-Bangla characters before tokenization.
            modified_text = remove_non_bangla(text)
            # Fix: the original called `ogger.info(...)` — an undefined name —
            # which raised NameError for every Bangla input and forced the
            # generic 500 error path below.
            logger.info(f"modified text: {modified_text}")

            # Encode and predict for texts containing Bangla characters.
            inputs = tokenizer.encode_plus(
                modified_text,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )

            with torch.no_grad():
                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
                prediction = torch.argmax(logits, dim=1).item()
                label = "Spam" if prediction == 1 else "Ham"
                results.append({"id": idx + 1, "text": text, "prediction": label})

        logger.info(f"Batch prediction results: {results}")
        return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")

    except Exception as e:
        logger.error(f"Batch prediction failed: {e}")
        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")

@app.get("/")
async def root():
    """Landing endpoint: return a static welcome message for the API root."""
    greeting = {"message": "Welcome to the MobileBERT API"}
    return greeting