Bijoy09 commited on
Commit
41034fa
1 Parent(s): 6632a1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -31,7 +31,7 @@ async def predict(request: TextRequest):
31
  inputs = tokenizer.encode_plus(
32
  request.text,
33
  add_special_tokens=True,
34
- max_length=64,
35
  truncation=True,
36
  padding='max_length',
37
  return_attention_mask=True,
@@ -50,11 +50,11 @@ async def batch_predict(request: BatchTextRequest):
50
  try:
51
  model.eval()
52
  results = []
53
- for text in request.texts:
54
  inputs = tokenizer.encode_plus(
55
  text,
56
  add_special_tokens=True,
57
- max_length=64,
58
  truncation=True,
59
  padding='max_length',
60
  return_attention_mask=True,
@@ -63,11 +63,12 @@ async def batch_predict(request: BatchTextRequest):
63
  with torch.no_grad():
64
  logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
65
  prediction = torch.argmax(logits, dim=1).item()
66
- results.append({"text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
67
  return {"results": results}
68
  except Exception as e:
69
  raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
70
 
 
71
  @app.get("/")
72
  async def root():
73
  return {"message": "Welcome to the MobileBERT API"}
 
31
  inputs = tokenizer.encode_plus(
32
  request.text,
33
  add_special_tokens=True,
34
+ max_length=128,
35
  truncation=True,
36
  padding='max_length',
37
  return_attention_mask=True,
 
50
  try:
51
  model.eval()
52
  results = []
53
+ for idx, text in enumerate(request.texts):
54
  inputs = tokenizer.encode_plus(
55
  text,
56
  add_special_tokens=True,
57
+ max_length=128,
58
  truncation=True,
59
  padding='max_length',
60
  return_attention_mask=True,
 
63
  with torch.no_grad():
64
  logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
65
  prediction = torch.argmax(logits, dim=1).item()
66
+ results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
67
  return {"results": results}
68
  except Exception as e:
69
  raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
70
 
71
+
72
  @app.get("/")
73
  async def root():
74
  return {"message": "Welcome to the MobileBERT API"}