Bijoy09 committed
Commit
714fa8b
1 Parent(s): 46cd2b9

Update app.py

Files changed (1)
  1. app.py +48 -16
app.py CHANGED
@@ -45,28 +45,60 @@ async def predict(request: TextRequest):
         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
 
 
+# @app.post("/batch_predict")
+# async def batch_predict(request: BatchTextRequest):
+#     try:
+#         model.eval()
+#         results = []
+#         for idx, text in enumerate(request.texts):
+#             inputs = tokenizer.encode_plus(
+#                 text,
+#                 add_special_tokens=True,
+#                 max_length=64,
+#                 truncation=True,
+#                 padding='max_length',
+#                 return_attention_mask=True,
+#                 return_tensors='pt'
+#             )
+#             with torch.no_grad():
+#                 logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+#             prediction = torch.argmax(logits, dim=1).item()
+#             results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
+#         return {"results": results}
+#     except Exception as e:
+#         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
 @app.post("/batch_predict")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-        results = []
-        for idx, text in enumerate(request.texts):
-            inputs = tokenizer.encode_plus(
-                text,
-                add_special_tokens=True,
-                max_length=64,
-                truncation=True,
-                padding='max_length',
-                return_attention_mask=True,
-                return_tensors='pt'
-            )
-            with torch.no_grad():
-                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-            prediction = torch.argmax(logits, dim=1).item()
-            results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
+
+        # Batch encode all texts in the request at once
+        inputs = tokenizer(
+            request.texts,
+            add_special_tokens=True,
+            max_length=64,
+            truncation=True,
+            padding='max_length',
+            return_attention_mask=True,
+            return_tensors='pt'
+        )
+
+        # Run batch inference
+        with torch.no_grad():
+            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+        predictions = torch.argmax(logits, dim=1).tolist()
+
+        # Format results
+        results = [
+            {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
+            for idx, (text, pred) in enumerate(zip(request.texts, predictions))
+        ]
+
         return {"results": results}
+
     except Exception as e:
+        logging.error(f"Batch prediction failed: {e}")
+        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 
 @app.get("/")
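
For quick testing of the updated endpoint, here is a minimal client-side sketch. It assumes the app is served locally at http://localhost:8000 and that BatchTextRequest is a Pydantic model exposing a texts list of strings (inferred from the request.texts usage in the handler above); the URL and sample messages are placeholders, not part of this commit.

# Hypothetical client for POST /batch_predict.
# Assumptions (not shown in this diff): the server listens on localhost:8000
# and the request body has the shape {"texts": ["...", ...]}.
import requests

payload = {
    "texts": [
        "Congratulations! You have won a free prize, click here to claim.",
        "Are we still meeting for lunch tomorrow?",
    ]
}

response = requests.post("http://localhost:8000/batch_predict", json=payload)
response.raise_for_status()

# Expected response shape, based on the handler above:
# {"results": [{"id": 1, "text": "...", "prediction": "Spam"}, ...]}
for item in response.json()["results"]:
    print(item["id"], item["prediction"], "-", item["text"])

Compared with the previous per-message loop, the handler now tokenizes the whole request in a single tokenizer(...) call and runs one forward pass, while the response shape ({"results": [...]}) stays the same for existing clients.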