Bijoy09 committed on
Commit
cd6d890
1 Parent(s): d8c29e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -2,30 +2,44 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
5
 
6
  app = FastAPI()
7
 
 
 
 
8
  # Load model and tokenizer
9
- model_name = "Bijoy09/MObilebert"
10
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
12
 
13
  class TextRequest(BaseModel):
14
  text: str
15
 
16
  @app.post("/predict/")
17
  async def predict(request: TextRequest):
18
- model.eval()
19
- inputs = tokenizer.encode_plus(
20
- request.text,
21
- add_special_tokens=True,
22
- max_length=128,
23
- truncation=True,
24
- padding='max_length',
25
- return_attention_mask=True,
26
- return_tensors='pt'
27
- )
28
- with torch.no_grad():
29
- logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
30
- prediction = torch.argmax(logits, dim=1).item()
31
- return {"prediction": "Spam" if prediction == 1 else "Ham"}
 
 
 
 
 
 
 
 
2
  from pydantic import BaseModel
3
  import torch
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ import os
6
 
7
  app = FastAPI()
8
 
9
+ # Set the cache directory for Hugging Face
10
+ os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
11
+
12
  # Load model and tokenizer
13
+ model_name = "Bijoy09/Bangla_spam_sms_detection_app"
14
+ try:
15
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ except Exception as e:
18
+ raise RuntimeError(f"Failed to load model or tokenizer: {e}")
19
 
20
  class TextRequest(BaseModel):
21
  text: str
22
 
23
  @app.post("/predict/")
24
  async def predict(request: TextRequest):
25
+ try:
26
+ model.eval()
27
+ inputs = tokenizer.encode_plus(
28
+ request.text,
29
+ add_special_tokens=True,
30
+ max_length=64,
31
+ truncation=True,
32
+ padding='max_length',
33
+ return_attention_mask=True,
34
+ return_tensors='pt'
35
+ )
36
+ with torch.no_grad():
37
+ logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
38
+ prediction = torch.argmax(logits, dim=1).item()
39
+ return {"prediction": "Spam" if prediction == 1 else "Ham"}
40
+ except Exception as e:
41
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
42
+
43
+ @app.get("/")
44
+ async def root():
45
+ return {"message": "Welcome to the MobileBERT API"}