Bijoy09 committed on
Commit
6632a1b
1 Parent(s): 3adf7f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -1,47 +1,32 @@
1
  from fastapi import FastAPI, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
  import os
7
- import logging
8
 
9
  app = FastAPI()
10
 
11
- # Configure logging
12
- logging.basicConfig(level=logging.INFO)
13
- logger = logging.getLogger(__name__)
14
-
15
  # Set the cache directory for Hugging Face
16
  os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
17
 
18
- # Enable CORS
19
- app.add_middleware(
20
- CORSMiddleware,
21
- allow_origins=["*"],
22
- allow_credentials=True,
23
- allow_methods=["*"],
24
- allow_headers=["*"],
25
- )
26
-
27
  # Load model and tokenizer
28
  model_name = "Bijoy09/MObilebert"
29
  try:
30
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
31
  tokenizer = AutoTokenizer.from_pretrained(model_name)
32
- logger.info("Model and tokenizer loaded successfully")
33
  except Exception as e:
34
- logger.error(f"Failed to load model or tokenizer: {e}")
35
  raise RuntimeError(f"Failed to load model or tokenizer: {e}")
36
 
37
  class TextRequest(BaseModel):
38
  text: str
39
 
 
 
 
40
  @app.post("/predict/")
41
  @app.post("/predict")
42
  async def predict(request: TextRequest):
43
  try:
44
- logger.info(f"Received text: {request.text}")
45
  model.eval()
46
  inputs = tokenizer.encode_plus(
47
  request.text,
@@ -52,16 +37,37 @@ async def predict(request: TextRequest):
52
  return_attention_mask=True,
53
  return_tensors='pt'
54
  )
55
- logger.info(f"Tokenized inputs: {inputs}")
56
  with torch.no_grad():
57
  logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
58
- logger.info(f"Model logits: {logits}")
59
  prediction = torch.argmax(logits, dim=1).item()
60
  return {"prediction": "Spam" if prediction == 1 else "Ham"}
61
  except Exception as e:
62
- logger.error(f"Prediction failed: {e}")
63
  raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  @app.get("/")
66
  async def root():
67
  return {"message": "Welcome to the MobileBERT API"}
 
1
  from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
3
  import torch
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
  import os
 
6
 
app = FastAPI()

# Set the cache directory for Hugging Face.
# NOTE: `transformers` is imported before this point and reads
# TRANSFORMERS_CACHE when it is imported, so exporting the variable here does
# not by itself redirect the cache. The export is kept for backward
# compatibility, and the directory is also passed explicitly as `cache_dir`
# below so that it is actually used.
_cache_dir = os.getenv('TRANSFORMERS_CACHE', '/app/cache')
os.environ['TRANSFORMERS_CACHE'] = _cache_dir

# Load model and tokenizer once at import time; every request reuses them.
model_name = "Bijoy09/MObilebert"
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=_cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=_cache_dir)
except Exception as e:
    # Fail fast: the service cannot answer anything without the model.
    # `from e` keeps the original traceback attached to the RuntimeError.
    raise RuntimeError(f"Failed to load model or tokenizer: {e}") from e
19
 
# Request body accepted by the single-text `predict` endpoint.
class TextRequest(BaseModel):
    # The text to classify.
    text: str
22
 
# Request body accepted by the `batch_predict` endpoint.
class BatchTextRequest(BaseModel):
    # The texts to classify; each one produces one entry in the response.
    texts: list[str]
25
+
26
  @app.post("/predict/")
27
  @app.post("/predict")
28
  async def predict(request: TextRequest):
29
  try:
 
30
  model.eval()
31
  inputs = tokenizer.encode_plus(
32
  request.text,
 
37
  return_attention_mask=True,
38
  return_tensors='pt'
39
  )
 
40
  with torch.no_grad():
41
  logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
 
42
  prediction = torch.argmax(logits, dim=1).item()
43
  return {"prediction": "Spam" if prediction == 1 else "Ham"}
44
  except Exception as e:
 
45
  raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
46
 
@app.post("/batch_predict/")
@app.post("/batch_predict")
async def batch_predict(request: BatchTextRequest):
    """Classify every text in the request as "Spam" or "Ham".

    Texts are tokenized and run through the model in chunks instead of one
    at a time, so a request costs one forward pass per chunk rather than one
    per text. An empty ``texts`` list yields an empty ``results`` list.

    Returns:
        ``{"results": [{"text": <input>, "prediction": "Spam" | "Ham"}, ...]}``
        with entries in the same order as ``request.texts``.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            tokenization or inference fails.
    """
    # Upper bound on texts per forward pass: keeps peak memory bounded for
    # large requests while still amortising the per-call overhead.
    chunk_size = 32
    try:
        model.eval()
        texts = request.texts
        results = []
        for start in range(0, len(texts), chunk_size):
            chunk = texts[start:start + chunk_size]
            # Calling the tokenizer with a list encodes the whole chunk at
            # once; every row is padded/truncated to 64 tokens as before.
            inputs = tokenizer(
                chunk,
                add_special_tokens=True,
                max_length=64,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
            # One predicted class index per row of the chunk.
            labels = torch.argmax(logits, dim=1).tolist()
            results.extend(
                {"text": text, "prediction": "Spam" if label == 1 else "Ham"}
                for text, label in zip(chunk, labels)
            )
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}") from e
70
+
# Landing route: responds to GET "/" with a fixed welcome payload.
@app.get("/")
async def root():
    welcome = {"message": "Welcome to the MobileBERT API"}
    return welcome