Update app.py
app.py
CHANGED
@@ -53,11 +53,42 @@ async def predict(request: TextRequest):
         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
 
 
+# @app.post("/batch_predict/")
+# async def batch_predict(request: BatchTextRequest):
+#     try:
+#         model.eval()
+#         logger.info(f"Received batch prediction request for {len(request.texts)} texts")
+#         inputs = tokenizer(
+#             request.texts,
+#             add_special_tokens=True,
+#             max_length=64,
+#             truncation=True,
+#             padding='max_length',
+#             return_attention_mask=True,
+#             return_tensors='pt'
+#         )
+
+#         with torch.no_grad():
+#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+#             predictions = torch.argmax(logits, dim=1).tolist()
+
+#         results = [
+#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
+#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
+#         ]
+#         logger.info(f"Batch prediction results: {results}")
+#         return {"results": results}
+#     except Exception as e:
+#         logger.error(f"Batch prediction failed: {e}")
+#         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
+
+
 @app.post("/batch_predict/")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-
+
+        # Batch encode all texts in the request at once
         inputs = tokenizer(
             request.texts,
             add_special_tokens=True,
@@ -68,20 +99,22 @@ async def batch_predict(request: BatchTextRequest):
             return_tensors='pt'
         )
 
+        # Run batch inference
         with torch.no_grad():
             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
             predictions = torch.argmax(logits, dim=1).tolist()
-
+
+        # Format results
        results = [
             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
         ]
-
+
         return {"results": results}
-    except Exception as e:
-        logger.error(f"Batch prediction failed: {e}")
-        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
 
+    except Exception as e:
+        logging.error(f"Batch prediction failed: {e}")
+        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 @app.get("/")
 async def root():
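
For reference, a minimal client sketch against the reworked /batch_predict/ endpoint. It assumes the app is served locally on port 8000 and that BatchTextRequest (defined earlier in app.py, outside this diff) exposes a plain texts: list[str] field, which is what request.texts above suggests; the expected response shape follows from the results construction in the handler.

import requests

# Hypothetical payload matching the assumed BatchTextRequest model,
# e.g. class BatchTextRequest(BaseModel): texts: list[str]
payload = {
    "texts": [
        "Congratulations! You've won a free cruise, call now!",
        "Are we still meeting for lunch tomorrow?",
    ]
}

# POST to the batch endpoint added in this commit (local dev server assumed)
resp = requests.post("http://localhost:8000/batch_predict/", json=payload, timeout=30)
resp.raise_for_status()

# The handler returns {"results": [{"id": ..., "text": ..., "prediction": "Spam" | "Ham"}, ...]}
for item in resp.json()["results"]:
    print(item["id"], item["prediction"], "-", item["text"])

Note that the updated exception handler logs the error server-side and returns a generic detail string, so a failing request now surfaces as a 500 with "Batch prediction failed. Please try again." instead of echoing the raw exception text.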