Update app.py
app.py
CHANGED
@@ -3,17 +3,18 @@ from pydantic import BaseModel
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import os
+import re
 import logging
 
 app = FastAPI()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Set the cache directory for Hugging Face
 os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
 
 # Load model and tokenizer
 model_name = "Bijoy09/MObilebert"
 try:
@@ -23,99 +24,58 @@ try:
 except Exception as e:
     logger.error(f"Failed to load model or tokenizer: {e}")
     raise RuntimeError(f"Failed to load model or tokenizer: {e}")
 
 class TextRequest(BaseModel):
     text: str
 
 class BatchTextRequest(BaseModel):
     texts: list[str]
 
-
-@app.post("/predict/")
-async def predict(request: TextRequest):
-    try:
-        model.eval()
-        inputs = tokenizer.encode_plus(
-            request.text,
-            add_special_tokens=True,
-            max_length=64,
-            truncation=True,
-            padding='max_length',
-            return_attention_mask=True,
-            return_tensors='pt'
-        )
-        with torch.no_grad():
-            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-        prediction = torch.argmax(logits, dim=1).item()
-        return {"prediction": "Spam" if prediction == 1 else "Ham"}
-    except Exception as e:
-        logger.error(f"Prediction failed: {e}")
-        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
-
-
-# @app.post("/batch_predict/")
-# async def batch_predict(request: BatchTextRequest):
-#     try:
-#         model.eval()
-#         logger.info(f"Received batch prediction request for {len(request.texts)} texts")
-#         inputs = tokenizer(
-#             request.texts,
-#             add_special_tokens=True,
-#             max_length=64,
-#             truncation=True,
-#             padding='max_length',
-#             return_attention_mask=True,
-#             return_tensors='pt'
-#         )
-
-#         with torch.no_grad():
-#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-#             predictions = torch.argmax(logits, dim=1).tolist()
-
-#         results = [
-#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
-#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
-#         ]
-#         logger.info(f"Batch prediction results: {results}")
-#         return {"results": results}
-#     except Exception as e:
-#         logger.error(f"Batch prediction failed: {e}")
-#         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
-
-
+# Regular expression to detect Bangla characters
+bangla_regex = re.compile('[\u0980-\u09FF]')
+
+def contains_bangla(text):
+    return bool(bangla_regex.search(text))
+
 @app.post("/batch_predict/")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Prepare the batch results
+        results = []
+
+        for idx, text in enumerate(request.texts):
+            logger.info(f" texts: {text}")
+            # Check if text contains Bangla characters
+            if not contains_bangla(text):
+                results.append({"id": idx + 1, "text": text, "prediction": "other"})
+                continue
+
+            # Encode and predict for texts containing Bangla characters
+            inputs = tokenizer.encode_plus(
+                text,
+                add_special_tokens=True,
+                max_length=64,
+                truncation=True,
+                padding='max_length',
+                return_attention_mask=True,
+                return_tensors='pt'
+            )
+
+            with torch.no_grad():
+                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+                prediction = torch.argmax(logits, dim=1).item()
+                label = "Spam" if prediction == 1 else "Ham"
+                results.append({"id": idx + 1, "text": text, "prediction": label})
+
         logger.info(f"Batch prediction results: {results}")
         return {"results": results}
 
     except Exception as e:
-
+        logger.error(f"Batch prediction failed: {e}")
         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 @app.get("/")
 async def root():
     return {"message": "Welcome to the MobileBERT API"}
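The heart of the change is the contains_bangla gate: inputs with no Bangla characters are labeled "other" without ever touching the model. A minimal standalone check of the same regex (the sample strings are illustrative, not from the commit):

```python
import re

# Same pattern as app.py: U+0980-U+09FF is the Unicode block for Bangla (Bengali) script.
bangla_regex = re.compile('[\u0980-\u09FF]')

def contains_bangla(text):
    return bool(bangla_regex.search(text))

print(contains_bangla("আপনি কি আগামীকাল আসবেন?"))  # True: Bangla text goes to the model
print(contains_bangla("hello world"))              # False: labeled "other", no model call
print(contains_bangla("ok আচ্ছা"))                 # True: mixed text still goes to the model
```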
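One design note: the new endpoint tokenizes and classifies each Bangla text individually, while the commented-out code removed by this commit encoded the whole list in a single tokenizer call. If batch latency matters, the two ideas combine naturally. The sketch below is my own, not part of the commit; batch_predict_texts is a hypothetical helper that keeps the "other" routing but runs one forward pass over all Bangla texts:

```python
import re
import torch

# Same helper as app.py defines.
bangla_regex = re.compile('[\u0980-\u09FF]')

def contains_bangla(text):
    return bool(bangla_regex.search(text))

def batch_predict_texts(texts, tokenizer, model):
    model.eval()
    results = [None] * len(texts)

    # Route non-Bangla texts straight to "other", as the endpoint does.
    bangla_items = []
    for i, t in enumerate(texts):
        if contains_bangla(t):
            bangla_items.append((i, t))
        else:
            results[i] = {"id": i + 1, "text": t, "prediction": "other"}

    if bangla_items:
        # One tokenizer call and one forward pass for the whole Bangla subset.
        batch = tokenizer(
            [t for _, t in bangla_items],
            add_special_tokens=True,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            logits = model(batch['input_ids'], attention_mask=batch['attention_mask']).logits
        for (i, t), pred in zip(bangla_items, torch.argmax(logits, dim=1).tolist()):
            results[i] = {"id": i + 1, "text": t, "prediction": "Spam" if pred == 1 else "Ham"}

    return results
```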
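Finally, a quick way to exercise the updated endpoint end to end. The host, port, and sample texts are assumptions (e.g. the app running locally under uvicorn), not part of the commit:

```python
import requests

# Assumes the app is served locally, e.g.: uvicorn app:app --host 0.0.0.0 --port 8000
resp = requests.post(
    "http://localhost:8000/batch_predict/",
    json={"texts": ["আপনি একটি পুরস্কার জিতেছেন! এখনই ক্লিক করুন", "see you tomorrow"]},
)
resp.raise_for_status()
for item in resp.json()["results"]:
    print(item["id"], item["prediction"], "-", item["text"])
# The Bangla text receives a Spam/Ham label from the model;
# the English text comes back as "other" without a model call.
```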