Update app.py
Browse files
app.py
CHANGED
@@ -31,7 +31,7 @@ async def predict(request: TextRequest):
|
|
31 |
inputs = tokenizer.encode_plus(
|
32 |
request.text,
|
33 |
add_special_tokens=True,
|
34 |
-
max_length=
|
35 |
truncation=True,
|
36 |
padding='max_length',
|
37 |
return_attention_mask=True,
|
@@ -50,11 +50,11 @@ async def batch_predict(request: BatchTextRequest):
|
|
50 |
try:
|
51 |
model.eval()
|
52 |
results = []
|
53 |
-
for text in request.texts:
|
54 |
inputs = tokenizer.encode_plus(
|
55 |
text,
|
56 |
add_special_tokens=True,
|
57 |
-
max_length=
|
58 |
truncation=True,
|
59 |
padding='max_length',
|
60 |
return_attention_mask=True,
|
@@ -63,11 +63,12 @@ async def batch_predict(request: BatchTextRequest):
|
|
63 |
with torch.no_grad():
|
64 |
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
|
65 |
prediction = torch.argmax(logits, dim=1).item()
|
66 |
-
results.append({"text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
|
67 |
return {"results": results}
|
68 |
except Exception as e:
|
69 |
raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
|
70 |
|
|
|
71 |
@app.get("/")
|
72 |
async def root():
|
73 |
return {"message": "Welcome to the MobileBERT API"}
|
|
|
31 |
inputs = tokenizer.encode_plus(
|
32 |
request.text,
|
33 |
add_special_tokens=True,
|
34 |
+
max_length=128,
|
35 |
truncation=True,
|
36 |
padding='max_length',
|
37 |
return_attention_mask=True,
|
|
|
50 |
try:
|
51 |
model.eval()
|
52 |
results = []
|
53 |
+
for idx, text in enumerate(request.texts):
|
54 |
inputs = tokenizer.encode_plus(
|
55 |
text,
|
56 |
add_special_tokens=True,
|
57 |
+
max_length=128,
|
58 |
truncation=True,
|
59 |
padding='max_length',
|
60 |
return_attention_mask=True,
|
|
|
63 |
with torch.no_grad():
|
64 |
logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
|
65 |
prediction = torch.argmax(logits, dim=1).item()
|
66 |
+
results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
|
67 |
return {"results": results}
|
68 |
except Exception as e:
|
69 |
raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
|
70 |
|
71 |
+
|
72 |
@app.get("/")
|
73 |
async def root():
|
74 |
return {"message": "Welcome to the MobileBERT API"}
|