Bijoy09 committed on
Commit efcb399
1 Parent(s): d2281db

Update app.py

Files changed (1)
  1. app.py +39 -6
app.py CHANGED
@@ -53,11 +53,42 @@ async def predict(request: TextRequest):
         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
 
 
+# @app.post("/batch_predict/")
+# async def batch_predict(request: BatchTextRequest):
+#     try:
+#         model.eval()
+#         logger.info(f"Received batch prediction request for {len(request.texts)} texts")
+#         inputs = tokenizer(
+#             request.texts,
+#             add_special_tokens=True,
+#             max_length=64,
+#             truncation=True,
+#             padding='max_length',
+#             return_attention_mask=True,
+#             return_tensors='pt'
+#         )
+
+#         with torch.no_grad():
+#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+#             predictions = torch.argmax(logits, dim=1).tolist()
+
+#         results = [
+#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
+#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
+#         ]
+#         logger.info(f"Batch prediction results: {results}")
+#         return {"results": results}
+#     except Exception as e:
+#         logger.error(f"Batch prediction failed: {e}")
+#         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
+
+
 @app.post("/batch_predict/")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-        logger.info(f"Received batch prediction request for {len(request.texts)} texts")
+
+        # Batch encode all texts in the request at once
         inputs = tokenizer(
             request.texts,
             add_special_tokens=True,
@@ -68,20 +99,22 @@ async def batch_predict(request: BatchTextRequest):
             return_tensors='pt'
         )
 
+        # Run batch inference
         with torch.no_grad():
             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
             predictions = torch.argmax(logits, dim=1).tolist()
-
+
+        # Format results
         results = [
             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
         ]
-        logger.info(f"Batch prediction results: {results}")
+
         return {"results": results}
-    except Exception as e:
-        logger.error(f"Batch prediction failed: {e}")
-        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
 
+    except Exception as e:
+        logging.error(f"Batch prediction failed: {e}")
+        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 @app.get("/")
 async def root():
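
For reference, a minimal client sketch that exercises the updated /batch_predict/ endpoint. The host and port are assumptions (not part of this commit), and the request body is inferred only from the texts field of BatchTextRequest used in the handler above.

# Hypothetical client sketch, not part of app.py.
# Assumed: the API is served at http://localhost:8000 and BatchTextRequest
# accepts a JSON body of the form {"texts": [...]}, matching request.texts.
import requests

payload = {"texts": ["Win a FREE prize now!!!", "Are we still meeting at 5?"]}
resp = requests.post("http://localhost:8000/batch_predict/", json=payload)
resp.raise_for_status()

# Per the handler, the response is {"results": [{"id": ..., "text": ..., "prediction": "Spam" or "Ham"}, ...]}
for item in resp.json()["results"]:
    print(item["id"], item["prediction"], "-", item["text"])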