Bijoy09 committed
Commit
0d35220
1 Parent(s): eb64215

Update app.py

Files changed (1)
  1. app.py +46 -86
app.py CHANGED
@@ -3,17 +3,18 @@ from pydantic import BaseModel
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import os
+import re
 import logging
-
+
 app = FastAPI()
-
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-
+
 # Set the cache directory for Hugging Face
 os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
-
+
 # Load model and tokenizer
 model_name = "Bijoy09/MObilebert"
 try:
@@ -23,99 +24,58 @@ try:
 except Exception as e:
     logger.error(f"Failed to load model or tokenizer: {e}")
     raise RuntimeError(f"Failed to load model or tokenizer: {e}")
-
+
 class TextRequest(BaseModel):
     text: str
-
+
 class BatchTextRequest(BaseModel):
     texts: list[str]
-
-
-@app.post("/predict/")
-async def predict(request: TextRequest):
-    try:
-        model.eval()
-        inputs = tokenizer.encode_plus(
-            request.text,
-            add_special_tokens=True,
-            max_length=64,
-            truncation=True,
-            padding='max_length',
-            return_attention_mask=True,
-            return_tensors='pt'
-        )
-        with torch.no_grad():
-            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-        prediction = torch.argmax(logits, dim=1).item()
-        return {"prediction": "Spam" if prediction == 1 else "Ham"}
-    except Exception as e:
-        logger.error(f"Prediction failed: {e}")
-        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
-
-
-# @app.post("/batch_predict/")
-# async def batch_predict(request: BatchTextRequest):
-#     try:
-#         model.eval()
-#         logger.info(f"Received batch prediction request for {len(request.texts)} texts")
-#         inputs = tokenizer(
-#             request.texts,
-#             add_special_tokens=True,
-#             max_length=64,
-#             truncation=True,
-#             padding='max_length',
-#             return_attention_mask=True,
-#             return_tensors='pt'
-#         )
-
-#         with torch.no_grad():
-#             logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-#         predictions = torch.argmax(logits, dim=1).tolist()
-
-#         results = [
-#             {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
-#             for idx, (text, pred) in enumerate(zip(request.texts, predictions))
-#         ]
-#         logger.info(f"Batch prediction results: {results}")
-#         return {"results": results}
-#     except Exception as e:
-#         logger.error(f"Batch prediction failed: {e}")
-#         raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")
-
-
+
+# Regular expression to detect Bangla characters
+bangla_regex = re.compile('[\u0980-\u09FF]')
+
+def contains_bangla(text):
+    return bool(bangla_regex.search(text))
+
 @app.post("/batch_predict/")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-
-        # Batch encode all texts in the request at once
-        inputs = tokenizer(
-            request.texts,
-            add_special_tokens=True,
-            max_length=64,
-            truncation=True,
-            padding='max_length',
-            return_attention_mask=True,
-            return_tensors='pt'
-        )
-
-        # Run batch inference
-        with torch.no_grad():
-            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
-        predictions = torch.argmax(logits, dim=1).tolist()
-
-        # Format results
-        results = [
-            {"id": idx + 1, "text": text, "prediction": "Spam" if pred == 1 else "Ham"}
-            for idx, (text, pred) in enumerate(zip(request.texts, predictions))
-        ]
+
+        # Prepare the batch results
+        results = []
+
+        for idx, text in enumerate(request.texts):
+            logger.info(f" texts: {text}")
+            # Check if text contains Bangla characters
+            if not contains_bangla(text):
+                results.append({"id": idx + 1, "text": text, "prediction": "other"})
+                continue
+
+            # Encode and predict for texts containing Bangla characters
+            inputs = tokenizer.encode_plus(
+                text,
+                add_special_tokens=True,
+                max_length=64,
+                truncation=True,
+                padding='max_length',
+                return_attention_mask=True,
+                return_tensors='pt'
+            )
+
+            with torch.no_grad():
+                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
+            prediction = torch.argmax(logits, dim=1).item()
+            label = "Spam" if prediction == 1 else "Ham"
+            results.append({"id": idx + 1, "text": text, "prediction": label})
+
         logger.info(f"Batch prediction results: {results}")
         return {"results": results}
-
+
     except Exception as e:
-        logging.error(f"Batch prediction failed: {e}")
+        logger.error(f"Batch prediction failed: {e}")
         raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
-
+
 @app.get("/")
 async def root():
-    return {"message": "Welcome to the MobileBERT API"}
+    return {"message": "Welcome to the MobileBERT API"}
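For context on the routing step this commit introduces: contains_bangla matches any character in the Unicode Bengali block (U+0980 to U+09FF), so a request text with no Bangla characters never reaches the model and is labeled "other", while mixed-script text still gets classified. Below is a minimal standalone sketch of that behavior; the sample strings are illustrative, not taken from the commit.

import re

# Same pattern as the committed helper: any character in the Bengali block
bangla_regex = re.compile('[\u0980-\u09FF]')

def contains_bangla(text):
    return bool(bangla_regex.search(text))

print(contains_bangla("আপনি একটি পুরস্কার জিতেছেন!"))  # True: pure Bangla, goes to the model
print(contains_bangla("Win a free prize now"))  # False: endpoint would return "other"
print(contains_bangla("Call করুন 16263"))  # True: mixed script still matches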
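And a hedged sketch of calling the updated endpoint from a client. The base URL is a placeholder (substitute wherever this Space is served), and the response shape follows the results list built in the handler.

import requests

API_URL = "http://localhost:8000"  # placeholder base URL, not from the commit

payload = {"texts": ["আপনি ১০,০০০ টাকা জিতেছেন!", "hello world"]}
response = requests.post(f"{API_URL}/batch_predict/", json=payload)
response.raise_for_status()

# Per the handler: non-Bangla text comes back as "other",
# Bangla text as "Spam" or "Ham"
for item in response.json()["results"]:
    print(item["id"], item["prediction"], item["text"])

One design note: the rewritten handler runs a separate forward pass for each Bangla text, where the removed version tokenized and classified the whole batch in a single call; for large requests, that batched approach could be reapplied to just the Bangla subset.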