Update app.py
Browse files
app.py
CHANGED
@@ -6,16 +6,16 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
6 |
import os
|
7 |
import re
|
8 |
import logging
|
9 |
-
|
10 |
app = FastAPI()

# Configure logging once at import time so the module-level logger below
# (and any library loggers) emit INFO-level records.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the cache directory for Hugging Face.
# setdefault keeps any value already provided by the environment and only
# falls back to /app/cache when the variable is unset — same end state as
# the original `os.environ[k] = os.getenv(k, default)` round-trip, without
# re-assigning a value that is already there.
os.environ.setdefault('TRANSFORMERS_CACHE', '/app/cache')
|
18 |
-
|
19 |
# Load model and tokenizer
|
20 |
model_name = "Bijoy09/MObilebert"
|
21 |
try:
|
@@ -25,45 +25,37 @@ try:
|
|
25 |
except Exception as e:
|
26 |
logger.error(f"Failed to load model or tokenizer: {e}")
|
27 |
raise RuntimeError(f"Failed to load model or tokenizer: {e}")
|
28 |
-
|
29 |
class TextRequest(BaseModel):
    """Request payload for a single-text prediction."""

    text: str


class BatchTextRequest(BaseModel):
    """Request payload for batch prediction over several texts."""

    texts: list[str]
|
34 |
-
|
35 |
# Compiled once at import: matches any single character inside the
# Bangla (Bengali) Unicode block.
bangla_regex = re.compile('[\u0980-\u09FF]')


def contains_bangla(text):
    """Return True when *text* contains at least one Bangla character."""
    return bangla_regex.search(text) is not None


def remove_non_bangla(text):
    """Return *text* reduced to only its Bangla characters, order preserved."""
    # Character-range filter is equivalent to joining bangla_regex.findall(text):
    # the pattern is a single inclusive range U+0980..U+09FF.
    return ''.join(ch for ch in text if '\u0980' <= ch <= '\u09FF')
|
43 |
-
|
44 |
@app.post("/batch_predict/")
|
45 |
async def batch_predict(request: BatchTextRequest):
|
46 |
try:
|
47 |
model.eval()
|
48 |
-
|
49 |
# Prepare the batch results
|
50 |
results = []
|
51 |
-
|
52 |
for idx, text in enumerate(request.texts):
|
53 |
-
|
54 |
-
|
55 |
# Check if text contains Bangla characters
|
56 |
if not contains_bangla(text):
|
57 |
results.append({"id": idx + 1, "text": text, "prediction": "other"})
|
58 |
continue
|
59 |
-
|
60 |
-
# Remove non-Bangla characters
|
61 |
-
modified_text = remove_non_bangla(text)
|
62 |
-
logger.info(f"modified text: {modified_text}")
|
63 |
-
|
64 |
# Encode and predict for texts containing Bangla characters
|
65 |
inputs = tokenizer.encode_plus(
|
66 |
-
|
67 |
add_special_tokens=True,
|
68 |
max_length=64,
|
69 |
truncation=True,
|
@@ -71,20 +63,20 @@ async def batch_predict(request: BatchTextRequest):
|
|
71 |
return_attention_mask=True,
|
72 |
return_tensors='pt'
|
73 |
)
|
74 |
-
|
75 |
with torch.no_grad():
|
76 |
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
|
77 |
prediction = torch.argmax(logits, dim=1).item()
|
78 |
label = "Spam" if prediction == 1 else "Ham"
|
79 |
results.append({"id": idx + 1, "text": text, "prediction": label})
|
80 |
-
|
81 |
logger.info(f"Batch prediction results: {results}")
|
82 |
return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")
|
83 |
-
|
84 |
except Exception as e:
|
85 |
logger.error(f"Batch prediction failed: {e}")
|
86 |
raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
|
87 |
-
|
88 |
@app.get("/")
|
89 |
async def root():
|
90 |
return {"message": "Welcome to the MobileBERT API"}
|
|
|
6 |
import os
|
7 |
import re
|
8 |
import logging
|
9 |
+
|
10 |
app = FastAPI()

# Configure logging once at import time so the module-level logger below
# (and any library loggers) emit INFO-level records.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the cache directory for Hugging Face.
# setdefault keeps any value already provided by the environment and only
# falls back to /app/cache when the variable is unset — same end state as
# the original `os.environ[k] = os.getenv(k, default)` round-trip, without
# re-assigning a value that is already there.
os.environ.setdefault('TRANSFORMERS_CACHE', '/app/cache')
|
18 |
+
|
19 |
# Load model and tokenizer
|
20 |
model_name = "Bijoy09/MObilebert"
|
21 |
try:
|
|
|
25 |
except Exception as e:
|
26 |
logger.error(f"Failed to load model or tokenizer: {e}")
|
27 |
raise RuntimeError(f"Failed to load model or tokenizer: {e}")
|
28 |
+
|
29 |
class TextRequest(BaseModel):
|
30 |
text: str
|
31 |
+
|
32 |
class BatchTextRequest(BaseModel):
|
33 |
texts: list[str]
|
34 |
+
|
35 |
# Regular expression to detect Bangla characters: matches any single
# code point in the Bangla (Bengali) Unicode block.
bangla_regex = re.compile('[\u0980-\u09FF]')


def contains_bangla(text):
    """Return True when *text* holds at least one Bangla character."""
    return bangla_regex.search(text) is not None
|
40 |
+
|
|
|
|
|
|
|
41 |
@app.post("/batch_predict/")
|
42 |
async def batch_predict(request: BatchTextRequest):
|
43 |
try:
|
44 |
model.eval()
|
45 |
+
|
46 |
# Prepare the batch results
|
47 |
results = []
|
48 |
+
|
49 |
for idx, text in enumerate(request.texts):
|
50 |
+
|
|
|
51 |
# Check if text contains Bangla characters
|
52 |
if not contains_bangla(text):
|
53 |
results.append({"id": idx + 1, "text": text, "prediction": "other"})
|
54 |
continue
|
55 |
+
|
|
|
|
|
|
|
|
|
56 |
# Encode and predict for texts containing Bangla characters
|
57 |
inputs = tokenizer.encode_plus(
|
58 |
+
text,
|
59 |
add_special_tokens=True,
|
60 |
max_length=64,
|
61 |
truncation=True,
|
|
|
63 |
return_attention_mask=True,
|
64 |
return_tensors='pt'
|
65 |
)
|
66 |
+
|
67 |
with torch.no_grad():
|
68 |
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
|
69 |
prediction = torch.argmax(logits, dim=1).item()
|
70 |
label = "Spam" if prediction == 1 else "Ham"
|
71 |
results.append({"id": idx + 1, "text": text, "prediction": label})
|
72 |
+
|
73 |
logger.info(f"Batch prediction results: {results}")
|
74 |
return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")
|
75 |
+
|
76 |
except Exception as e:
|
77 |
logger.error(f"Batch prediction failed: {e}")
|
78 |
raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
|
79 |
+
|
80 |
@app.get("/")
|
81 |
async def root():
|
82 |
return {"message": "Welcome to the MobileBERT API"}
|