aznasut commited on
Commit
cbc96c3
·
1 Parent(s): a43fbe2

remove tf-keras

Browse files
Files changed (2) hide show
  1. main.py +45 -30
  2. requirements.txt +0 -1
main.py CHANGED
@@ -11,7 +11,6 @@ from transformers.pipelines import PipelineException
11
  from transformers import AutoImageProcessor, ViTForImageClassification
12
  from PIL import Image
13
  from cachetools import Cache
14
- import tensorflow as tf
15
  import torch
16
  import torch.nn.functional as F
17
  from models import (
@@ -34,12 +33,12 @@ cache = Cache(maxsize=1000)
34
  # model = pipeline("image-classification", model="Wvolf/ViT_Deepfake_Detection")
35
 
36
  # Detect the device used by TensorFlow
37
- DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
38
- logging.info("TensorFlow version: %s", tf.__version__)
39
- logging.info("Model is using: %s", DEVICE)
40
 
41
- if DEVICE == "GPU":
42
- logging.info("GPUs available: %d", len(tf.config.list_physical_devices("GPU")))
43
 
44
 
45
  async def download_image(image_url: str) -> bytes:
@@ -85,38 +84,58 @@ async def classify_image(file: UploadFile = File(None)):
85
 
86
  image = Image.open(io.BytesIO(image_data))
87
 
88
- # Use the model to classify the image
89
- # results = model(image)
90
-
91
- image_processor = AutoImageProcessor.from_pretrained("dima806/ai_vs_real_image_detection")
92
- model = ViTForImageClassification.from_pretrained("dima806/ai_vs_real_image_detection")
93
-
94
  inputs = image_processor(image, return_tensors="pt")
95
 
96
  with torch.no_grad():
97
  logits = model(**inputs).logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # model predicts one of the 1000 ImageNet classes
100
- predicted_label = logits.argmax(-1).item()
101
- logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
102
- # print(model.config.id2label[predicted_label])
103
  # Find the prediction with the highest confidence using the max() function
104
  # best_prediction = max(results, key=lambda x: x["score"])
105
- # logging.info("best_prediction %s", best_prediction)
106
- # best_prediction2 = results[1]["label"]
107
- # logging.info("best_prediction2 %s", best_prediction2)
108
 
109
- # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
110
  # confidence_percentage = round(best_prediction["score"] * 100, 1)
111
 
112
- # # Prepare the custom response data
113
- response_data = {
114
- "prediction": model.config.id2label[predicted_label],
115
- "confidence_percentage":model.config.id2label[predicted_label],
116
- }
117
 
118
  # Populate hash
119
- cache[image_hash] = response_data.copy()
 
 
 
 
 
120
 
121
  # Add file_name to the API response
122
  response_data["file_name"] = file.filename
@@ -214,10 +233,6 @@ async def classify_images(request: ImageUrlsRequest):
214
 
215
  return JSONResponse(status_code=200, content=response_data)
216
 
217
- @app.get("/hello")
218
- async def hello_world():
219
- return {"message": "hello_world"}
220
-
221
  if __name__ == "__main__":
222
  import uvicorn
223
 
 
11
  from transformers import AutoImageProcessor, ViTForImageClassification
12
  from PIL import Image
13
  from cachetools import Cache
 
14
  import torch
15
  import torch.nn.functional as F
16
  from models import (
 
33
  # model = pipeline("image-classification", model="Wvolf/ViT_Deepfake_Detection")
34
 
35
  # Detect the device used by TensorFlow
36
+ # DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
37
+ # logging.info("TensorFlow version: %s", tf.__version__)
38
+ # logging.info("Model is using: %s", DEVICE)
39
 
40
+ # if DEVICE == "GPU":
41
+ # logging.info("GPUs available: %d", len(tf.config.list_physical_devices("GPU")))
42
 
43
 
44
  async def download_image(image_url: str) -> bytes:
 
84
 
85
  image = Image.open(io.BytesIO(image_data))
86
 
87
+ image_processor = AutoImageProcessor.from_pretrained("Wvolf/ViT_Deepfake_Detection")
88
+ model = ViTForImageClassification.from_pretrained("Wvolf/ViT_Deepfake_Detection")
 
 
 
 
89
  inputs = image_processor(image, return_tensors="pt")
90
 
91
  with torch.no_grad():
92
  logits = model(**inputs).logits
93
+ probs = F.softmax(logits, dim=-1)
94
+ predicted_label_id = probs.argmax(-1).item()
95
+ predicted_label = model.config.id2label[predicted_label_id]
96
+ confidence = probs.max().item()
97
+
98
+ # model predicts one of the 1000 ImageNet classes
99
+ # predicted_label = logits.argmax(-1).item()
100
+ # logging.info("predicted_label", predicted_label)
101
+ # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
102
+ # # print(model.config.id2label[predicted_label])
103
+ # Find the prediction with the highest confidence using the max() function
104
+ # best_prediction = max(results, key=lambda x: x["score"])
105
+ # logging.info("best_prediction %s", best_prediction)
106
+ # best_prediction2 = results[1]["label"]
107
+ # logging.info("best_prediction2 %s", best_prediction2)
108
+
109
+ # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
110
+ # confidence_percentage = round(best_prediction["score"] * 100, 1)
111
+
112
+ # # Prepare the custom response data
113
+ detection_result = {
114
+ "prediction": predicted_label,
115
+ "confidence_percentage":confidence,
116
+ }
117
+ # Use the model to classify the image
118
+ # results = model(image)
119
 
 
 
 
 
120
  # Find the prediction with the highest confidence using the max() function
121
  # best_prediction = max(results, key=lambda x: x["score"])
 
 
 
122
 
123
+ # Calculate the confidence score, rounded to the nearest tenth and as a percentage
124
  # confidence_percentage = round(best_prediction["score"] * 100, 1)
125
 
126
+ # Prepare the custom response data
127
+ # detection_result = {
128
+ # "is_nsfw": best_prediction["label"] == "nsfw",
129
+ # "confidence_percentage": confidence_percentage,
130
+ # }
131
 
132
  # Populate hash
133
+ cache[image_hash] = detection_result.copy()
134
+
135
+ # Add url to the API response
136
+ detection_result["file_name"] = file.filename
137
+
138
+ response_data.append(detection_result)
139
 
140
  # Add file_name to the API response
141
  response_data["file_name"] = file.filename
 
233
 
234
  return JSONResponse(status_code=200, content=response_data)
235
 
 
 
 
 
236
  if __name__ == "__main__":
237
  import uvicorn
238
 
requirements.txt CHANGED
@@ -5,6 +5,5 @@ aiohttp==3.9.5
5
  pillow==10.3.0
6
  python-multipart==0.0.9
7
  torch
8
- tf-keras==2.16.0
9
  cachetools===5.3.3
10
  pydantic===2.7.2
 
5
  pillow==10.3.0
6
  python-multipart==0.0.9
7
  torch
 
8
  cachetools===5.3.3
9
  pydantic===2.7.2