Update app.py
app.py
CHANGED
@@ -3,9 +3,14 @@ from pydantic import BaseModel
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import os
+import logging
 
 app = FastAPI()
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Set the cache directory for Hugging Face
 os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
 
@@ -14,7 +19,9 @@ model_name = "Bijoy09/MObilebert"
 try:
     model = AutoModelForSequenceClassification.from_pretrained(model_name)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+    logger.info("Model and tokenizer loaded successfully")
 except Exception as e:
+    logger.error(f"Failed to load model or tokenizer: {e}")
     raise RuntimeError(f"Failed to load model or tokenizer: {e}")
 
 class TextRequest(BaseModel):
@@ -42,6 +49,7 @@ async def predict(request: TextRequest):
         prediction = torch.argmax(logits, dim=1).item()
         return {"prediction": "Spam" if prediction == 1 else "Ham"}
     except Exception as e:
+        logger.error(f"Prediction failed: {e}")
         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
 
 
@@ -49,56 +57,30 @@ async def predict(request: TextRequest):
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        logger.info(f"Received batch prediction request for {len(request.texts)} texts")
+        inputs = tokenizer(
+            request.texts,
+            add_special_tokens=True,
+            max_length=64,
+            truncation=True,
+            padding='max_length',
+            return_attention_mask=True,
+            return_tensors='pt'
+        )
+
+        with torch.no_grad():
+            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+            predictions = torch.argmax(logits, dim=1).tolist()
+
+        results = [
+            {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
+            for idx, (text, pred) in enumerate(zip(request.texts, predictions))
+        ]
+        logger.info(f"Batch prediction results: {results}")
         return {"results": results}
     except Exception as e:
+        logger.error(f"Batch prediction failed: {e}")
         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
-# @app.post("/batch_predict")
-# async def batch_predict(request: BatchTextRequest):
-#     try:
-#         model.eval()
-
-#         # Batch encode all texts in the request at once
-#         inputs = tokenizer(
-#             request.texts,
-#             add_special_tokens=True,
-#             max_length=64,
-#             truncation=True,
-#             padding='max_length',
-#             return_attention_mask=True,
-#             return_tensors='pt'
-#         )
-
-#         # Run batch inference
-#         with torch.no_grad():
-#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-#             predictions = torch.argmax(logits, dim=1).tolist()
-
-#         # Format results
-#         results = [
-#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
-#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
-#         ]
-
-#         return {"results": results}
-
-#     except Exception as e:
-#         logging.error(f"Batch prediction failed: {e}")
-#         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 
 @app.get("/")
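For reference, here is a minimal client-side sketch of how the updated endpoints could be exercised once the app is running. The base URL, the /predict and /batch_predict paths, and the text/texts field names are assumptions inferred from the request models and the previously commented-out decorator; the route decorators themselves are not visible in this diff.

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port for a local uvicorn run

# Single-text prediction; TextRequest is assumed to expose a `text` field
single = requests.post(f"{BASE_URL}/predict", json={"text": "Win a free prize now!"})
print(single.json())  # e.g. {"prediction": "Spam"}

# Batch prediction; BatchTextRequest exposes a `texts` list (see `request.texts` in the diff)
batch = requests.post(
    f"{BASE_URL}/batch_predict",
    json={"texts": ["Win a free prize now!", "Are we still meeting at 5?"]},
)
print(batch.json())  # e.g. {"results": [{"id": 1, "text": "...", "prediction": "Spam"}, ...]}

Note that the new batch route tokenizes all texts in a single tokenizer call padded/truncated to max_length=64, so longer messages are cut off before inference.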