Bijoy09 committed on
Commit 473261b
1 Parent(s): 2688a7f

Update app.py

Files changed (1):
  1. app.py +29 -47
app.py CHANGED
@@ -3,9 +3,14 @@ from pydantic import BaseModel
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import os
+import logging
 
 app = FastAPI()
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Set the cache directory for Hugging Face
 os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
 
@@ -14,7 +19,9 @@ model_name = "Bijoy09/MObilebert"
 try:
     model = AutoModelForSequenceClassification.from_pretrained(model_name)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+    logger.info("Model and tokenizer loaded successfully")
 except Exception as e:
+    logger.error(f"Failed to load model or tokenizer: {e}")
     raise RuntimeError(f"Failed to load model or tokenizer: {e}")
 
 class TextRequest(BaseModel):
@@ -42,6 +49,7 @@ async def predict(request: TextRequest):
         prediction = torch.argmax(logits, dim=1).item()
         return {"prediction": "Spam" if prediction == 1 else "Ham"}
     except Exception as e:
+        logger.error(f"Prediction failed: {e}")
         raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
 
 
@@ -49,56 +57,30 @@ async def predict(request: TextRequest):
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-        results = []
-        for idx, text in enumerate(request.texts):
-            inputs = tokenizer.encode_plus(
-                text,
-                add_special_tokens=True,
-                max_length=64,
-                truncation=True,
-                padding='max_length',
-                return_attention_mask=True,
-                return_tensors='pt'
-            )
-            with torch.no_grad():
-                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-            prediction = torch.argmax(logits, dim=1).item()
-            results.append({"id": idx + 1, "text": text, "prediction": "Spam" if prediction == 1 else "Ham"})
+        logger.info(f"Received batch prediction request for {len(request.texts)} texts")
+        inputs = tokenizer(
+            request.texts,
+            add_special_tokens=True,
+            max_length=64,
+            truncation=True,
+            padding='max_length',
+            return_attention_mask=True,
+            return_tensors='pt'
+        )
+
+        with torch.no_grad():
+            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+        predictions = torch.argmax(logits, dim=1).tolist()
+
+        results = [
+            {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
+            for idx, (text, pred) in enumerate(zip(request.texts, predictions))
+        ]
+        logger.info(f"Batch prediction results: {results}")
         return {"results": results}
     except Exception as e:
+        logger.error(f"Batch prediction failed: {e}")
         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
-# @app.post("/batch_predict")
-# async def batch_predict(request: BatchTextRequest):
-#     try:
-#         model.eval()
-
-#         # Batch encode all texts in the request at once
-#         inputs = tokenizer(
-#             request.texts,
-#             add_special_tokens=True,
-#             max_length=64,
-#             truncation=True,
-#             padding='max_length',
-#             return_attention_mask=True,
-#             return_tensors='pt'
-#         )
-
-#         # Run batch inference
-#         with torch.no_grad():
-#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-#         predictions = torch.argmax(logits, dim=1).tolist()
-
-#         # Format results
-#         results = [
-#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
-#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
-#         ]
-
-#         return {"results": results}
-
-#     except Exception as e:
-#         logging.error(f"Batch prediction failed: {e}")
-#         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
 
 
 @app.get("/")
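
For context on the main change: the old handler encoded and scored each text in a Python loop, while the new one relies on the tokenizer's batch mode, which pads every text to the same length and stacks the tensors so a single forward pass scores the whole batch. A minimal sketch of that behaviour outside the app (the sample texts are made up; the settings mirror those in the diff):

# Sketch: calling the tokenizer on a list pads/stacks the texts into one
# tensor, so one no_grad forward pass scores the entire batch.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Bijoy09/MObilebert")
model = AutoModelForSequenceClassification.from_pretrained("Bijoy09/MObilebert")

texts = ["Win a free prize now!", "Lunch at noon?"]  # hypothetical samples
inputs = tokenizer(
    texts,
    add_special_tokens=True,
    max_length=64,
    truncation=True,
    padding='max_length',
    return_attention_mask=True,
    return_tensors='pt'
)
print(inputs['input_ids'].shape)  # torch.Size([2, 64]) -- one row per text

with torch.no_grad():
    logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
print(torch.argmax(logits, dim=1).tolist())  # e.g. [1, 0] -> Spam, Ham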
 
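A hedged smoke test for the two endpoints. The /predict route path and the text/texts field names are inferred from the handler signatures and request models visible in the diff, not confirmed by it (/batch_predict does appear in the removed comments); fastapi's TestClient needs httpx installed.

# Assumed: app.py is importable, /predict accepts {"text": ...} and
# /batch_predict accepts {"texts": [...]} -- field names inferred from the diff.
from fastapi.testclient import TestClient
from app import app

client = TestClient(app)

single = client.post("/predict", json={"text": "Win a free prize now!"})
print(single.json())  # e.g. {"prediction": "Spam"}

batch = client.post(
    "/batch_predict",
    json={"texts": ["Win a free prize now!", "Lunch at noon?"]},
)
print(batch.json())  # {"results": [{"id": 1, ...}, {"id": 2, ...}]}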