Bijoy09 commited on
Commit
3f6c2a7
1 Parent(s): c7fcd30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+
6
+ app = FastAPI()
7
+
8
+ # Load model and tokenizer
9
+ model_name = "Bijoy09/your_mobilebert_model_repo" # replace with your Hugging Face repo name
10
+ try:
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ except Exception as e:
14
+ raise RuntimeError(f"Failed to load model or tokenizer: {e}")
15
+
16
+ class TextRequest(BaseModel):
17
+ text: str
18
+
19
+ @app.post("/predict/")
20
+ async def predict(request: TextRequest):
21
+ try:
22
+ model.eval()
23
+ inputs = tokenizer.encode_plus(
24
+ request.text,
25
+ add_special_tokens=True,
26
+ max_length=64,
27
+ truncation=True,
28
+ padding='max_length',
29
+ return_attention_mask=True,
30
+ return_tensors='pt'
31
+ )
32
+ with torch.no_grad():
33
+ logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
34
+ prediction = torch.argmax(logits, dim=1).item()
35
+ return {"prediction": "Spam" if prediction == 1 else "Ham"}
36
+ except Exception as e:
37
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
38
+
39
+ @app.get("/")
40
+ async def root():
41
+ return {"message": "Welcome to the MobileBERT API"}