Bijoy09 committed
Commit 2688a7f
1 Parent(s): 714fa8b

Update app.py

Files changed (1)
  1. app.py +45 -45
app.py CHANGED
@@ -45,60 +45,60 @@ async def predict(request: TextRequest):
         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
 
 
-# @app.post("/batch_predict")
-# async def batch_predict(request: BatchTextRequest):
-#     try:
-#         model.eval()
-#         results = []
-#         for idx, text in enumerate(request.texts):
-#             inputs = tokenizer.encode_plus(
-#                 text,
-#                 add_special_tokens=True,
-#                 max_length=64,
-#                 truncation=True,
-#                 padding='max_length',
-#                 return_attention_mask=True,
-#                 return_tensors='pt'
-#             )
-#             with torch.no_grad():
-#                 logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-#             prediction = torch.argmax(logits, dim=1).item()
-#             results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
-#         return {"results": results}
-#     except Exception as e:
-#         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
 @app.post("/batch_predict")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
+        results = []
+        for idx, text in enumerate(request.texts):
+            inputs = tokenizer.encode_plus(
+                text,
+                add_special_tokens=True,
+                max_length=64,
+                truncation=True,
+                padding='max_length',
+                return_attention_mask=True,
+                return_tensors='pt'
+            )
+            with torch.no_grad():
+                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+            prediction = torch.argmax(logits, dim=1).item()
+            results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
+        return {"results": results}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
+# @app.post("/batch_predict")
+# async def batch_predict(request: BatchTextRequest):
+#     try:
+#         model.eval()
 
-        # Batch encode all texts in the request at once
-        inputs = tokenizer(
-            request.texts,
-            add_special_tokens=True,
-            max_length=64,
-            truncation=True,
-            padding='max_length',
-            return_attention_mask=True,
-            return_tensors='pt'
-        )
+#         # Batch encode all texts in the request at once
+#         inputs = tokenizer(
+#             request.texts,
+#             add_special_tokens=True,
+#             max_length=64,
+#             truncation=True,
+#             padding='max_length',
+#             return_attention_mask=True,
+#             return_tensors='pt'
+#         )
 
-        # Run batch inference
-        with torch.no_grad():
-            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-            predictions = torch.argmax(logits, dim=1).tolist()
+#         # Run batch inference
+#         with torch.no_grad():
+#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+#             predictions = torch.argmax(logits, dim=1).tolist()
 
-        # Format results
-        results = [
-            {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
-            for idx, (text, pred) in enumerate(zip(request.texts, predictions))
-        ]
+#         # Format results
+#         results = [
+#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
+#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
+#         ]
 
-        return {"results": results}
+#         return {"results": results}
 
-    except Exception as e:
-        logging.error(f"Batch prediction failed: {e}")
-        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
+#     except Exception as e:
+#         logging.error(f"Batch prediction failed: {e}")
+#         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 
 @app.get("/")