Bijoy09 committed on
Commit 1e28e60
1 Parent(s): d1b5757

Update app.py

Files changed (1)
  1. app.py +26 -18
app.py CHANGED
@@ -6,16 +6,16 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import os
 import re
 import logging
-
+
 app = FastAPI()
-
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
+logger = logging.getLogger(__name__)
+
 # Set the cache directory for Hugging Face
 os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
-
+
 # Load model and tokenizer
 model_name = "Bijoy09/MObilebert"
 try:
@@ -25,37 +25,45 @@ try:
 except Exception as e:
     logger.error(f"Failed to load model or tokenizer: {e}")
     raise RuntimeError(f"Failed to load model or tokenizer: {e}")
-
+
 class TextRequest(BaseModel):
     text: str
-
+
 class BatchTextRequest(BaseModel):
     texts: list[str]
-
+
 # Regular expression to detect Bangla characters
 bangla_regex = re.compile('[\u0980-\u09FF]')
-
+
 def contains_bangla(text):
     return bool(bangla_regex.search(text))
-
+
+def remove_non_bangla(text):
+    return ''.join(bangla_regex.findall(text))
+
 @app.post("/batch_predict/")
 async def batch_predict(request: BatchTextRequest):
     try:
         model.eval()
-
+
         # Prepare the batch results
         results = []
-
+
         for idx, text in enumerate(request.texts):
             logger.info(f" texts: {text}")
+
             # Check if text contains Bangla characters
             if not contains_bangla(text):
                 results.append({"id": idx + 1, "text": text, "prediction": "other"})
                 continue
-
+
+            # Remove non-Bangla characters
+            modified_text = remove_non_bangla(text)
+            logger.info(f"modified text: {modified_text}")
+
            # Encode and predict for texts containing Bangla characters
             inputs = tokenizer.encode_plus(
-                text,
+                modified_text,
                 add_special_tokens=True,
                 max_length=64,
                 truncation=True,
@@ -63,20 +71,20 @@ async def batch_predict(request: BatchTextRequest):
                 return_attention_mask=True,
                 return_tensors='pt'
             )
-
+
             with torch.no_grad():
                 logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
             prediction = torch.argmax(logits, dim=1).item()
             label = "Spam" if prediction == 1 else "Ham"
             results.append({"id": idx + 1, "text": text, "prediction": label})
-
+
         logger.info(f"Batch prediction results: {results}")
         return JSONResponse(content={"results": results}, media_type="application/json; charset=utf-8")
-
+
     except Exception as e:
         logger.error(f"Batch prediction failed: {e}")
        raise HTTPException(status_code=500, detail="Batch prediction failed. Please try again.")
-
+
 @app.get("/")
 async def root():
     return {"message": "Welcome to the MobileBERT API"}