from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os

# Set the Hugging Face cache directory before importing transformers,
# since the default cache location is resolved at import time.
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

# Load model and tokenizer
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    raise RuntimeError(f"Failed to load model or tokenizer: {e}")
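# Note: inference runs on the CPU as written. A possible optimization (a sketch,
# not part of the original app) would be to move the model to a GPU when one is
# available, e.g.:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model.to(device)
# and move each batch of input tensors to the same device before the forward pass.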

class TextRequest(BaseModel):
    text: str

class BatchTextRequest(BaseModel):
    texts: list[str]
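
# Example request payloads (illustrative; the field names follow the models above):
#   POST /predict        {"text": "Congratulations, you won a free prize!"}
#   POST /batch_predict  {"texts": ["Win cash now!!!", "See you at lunch"]}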


@app.post("/predict")
async def predict(request: TextRequest):
    try:
        model.eval()
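        # Tokenize the input text, padding/truncating it to a fixed 64-token sequence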
        inputs = tokenizer.encode_plus(
            request.text,
            add_special_tokens=True,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
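        # Forward pass without gradient tracking; argmax over the logits gives the class id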
        with torch.no_grad():
            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
            prediction = torch.argmax(logits, dim=1).item()
        return {"prediction": "Spam" if prediction == 1 else "Ham"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")


@app.post("/batch_predict")
async def batch_predict(request: BatchTextRequest):
    try:
        model.eval()
        results = []
        for idx, text in enumerate(request.texts):
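            # Encode and classify each text individually; a single-pass batched
            # variant is sketched in the commented-out code below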
            inputs = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
                prediction = torch.argmax(logits, dim=1).item()
                results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")

# Alternative /batch_predict implementation kept for reference: it encodes all
# texts in one call and runs a single batched forward pass. Enabling it would
# also require `import logging` for the error handler below.
# @app.post("/batch_predict")
# async def batch_predict(request: BatchTextRequest):
#     try:
#         model.eval()
        
#         # Batch encode all texts in the request at once
#         inputs = tokenizer(
#             request.texts,
#             add_special_tokens=True,
#             max_length=64,
#             truncation=True,
#             padding='max_length',
#             return_attention_mask=True,
#             return_tensors='pt'
#         )

#         # Run batch inference
#         with torch.no_grad():
#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
#             predictions = torch.argmax(logits, dim=1).tolist()
        
#         # Format results
#         results = [
#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
#         ]
        
#         return {"results": results}

#     except Exception as e:
#         logging.error(f"Batch prediction failed: {e}")
#         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")


@app.get("/")
async def root():
    return {"message": "Welcome to the MobileBERT API"}
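

# Local usage sketch (assumes this module is saved as app.py and served on port
# 7860, the Hugging Face Spaces default; both are assumptions, not confirmed here):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Free entry in a weekly competition!"}'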