File size: 971 Bytes
3f6c2a7
 
 
 
 
 
 
 
2dfc2fc
dbe1f8e
 
3f6c2a7
 
 
 
 
 
dbe1f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import asyncio

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Application object; the route handlers in this module register on it.
app = FastAPI()

# Load the sequence-classification model and its matching tokenizer once, at
# import time, so every request reuses the same instances instead of
# reloading them per call.
# NOTE(review): "Bijoy09/MObilebert" looks like a Hugging Face Hub repo id, in
# which case the first import needs network access or a warm local cache —
# confirm for the deployment environment.
model_name = "Bijoy09/MObilebert"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

class TextRequest(BaseModel):
    """Request body for the prediction endpoint: the text to classify."""
    # Raw input text. Required: no default value is declared.
    text: str

def _classify(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> int:
    """Run one forward pass and return the index of the highest-scoring class.

    Gradient tracking is disabled because this is inference only. The
    ``.item()`` call requires a single-row batch, which is what ``predict``
    supplies.
    """
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    return int(torch.argmax(logits, dim=1).item())


@app.post("/predict/")
async def predict(request: TextRequest):
    """Classify ``request.text`` and report the label.

    The text is tokenized with special tokens, truncated and padded to exactly
    128 tokens, and passed through the model. Class index 1 is reported as
    ``"Spam"``; any other index is reported as ``"Ham"``.

    Args:
        request: Request body carrying the text to classify.

    Returns:
        A dict of the form ``{"prediction": "Spam"}`` or
        ``{"prediction": "Ham"}``.
    """
    model.eval()

    # Tokenize on the event-loop thread: it is cheap, and it keeps the shared
    # tokenizer from being used by several worker threads at once.
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``.
    inputs = tokenizer(
        request.text,
        add_special_tokens=True,
        max_length=128,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )

    # The forward pass is blocking, CPU-bound work. Running it directly inside
    # this coroutine would stall the event loop (and every other in-flight
    # request) until it finished, so it is handed to the default executor.
    loop = asyncio.get_running_loop()
    prediction = await loop.run_in_executor(
        None, _classify, inputs['input_ids'], inputs['attention_mask']
    )
    return {"prediction": "Spam" if prediction == 1 else "Ham"}