Spaces:
Runtime error
Runtime error
vincentclaes
commited on
Commit
•
5ed6ee0
1
Parent(s):
262aca1
return confidence score
Browse files
app.py
CHANGED
@@ -10,23 +10,25 @@ APP_NAME = "Mona Lisa Detection"
|
|
10 |
|
11 |
logger.debug("loading processor and model.")
|
12 |
processor = AutoFeatureExtractor.from_pretrained(
|
13 |
-
"drift-ai/autotrain-mona-lisa-detection-38345101350",
|
14 |
-
use_auth_token=True
|
15 |
)
|
16 |
model = AutoModelForImageClassification.from_pretrained(
|
17 |
-
"drift-ai/autotrain-mona-lisa-detection-38345101350",
|
18 |
-
use_auth_token=True
|
19 |
)
|
20 |
logger.debug("loading processor and model succeeded.")
|
21 |
|
22 |
|
23 |
def process_image(image, model=model, processor=processor):
|
24 |
logger.info("Making a prediction ...")
|
|
|
25 |
inputs = processor(images=image, return_tensors="pt")
|
26 |
outputs = model(**inputs)
|
27 |
logits = outputs.logits
|
28 |
predicted_class_idx = logits.argmax(-1).item()
|
29 |
-
|
|
|
|
|
|
|
30 |
print("Predicted class:", result)
|
31 |
logger.info("Prediction finished.")
|
32 |
return result
|
|
|
10 |
|
11 |
logger.debug("loading processor and model.")
|
12 |
processor = AutoFeatureExtractor.from_pretrained(
|
13 |
+
"drift-ai/autotrain-mona-lisa-detection-38345101350", use_auth_token=True
|
|
|
14 |
)
|
15 |
model = AutoModelForImageClassification.from_pretrained(
|
16 |
+
"drift-ai/autotrain-mona-lisa-detection-38345101350", use_auth_token=True
|
|
|
17 |
)
|
18 |
logger.debug("loading processor and model succeeded.")
|
19 |
|
20 |
|
21 |
def process_image(image, model=model, processor=processor):
|
22 |
logger.info("Making a prediction ...")
|
23 |
+
|
24 |
inputs = processor(images=image, return_tensors="pt")
|
25 |
outputs = model(**inputs)
|
26 |
logits = outputs.logits
|
27 |
predicted_class_idx = logits.argmax(-1).item()
|
28 |
+
|
29 |
+
label = {1: "Not Mona Lisa", 0: "Mona Lisa"}
|
30 |
+
predictions = logits.softmax(dim=-1).tolist()
|
31 |
+
result = {label[predicted_class_idx]: predictions[0][predicted_class_idx]}
|
32 |
print("Predicted class:", result)
|
33 |
logger.info("Prediction finished.")
|
34 |
return result
|