vincentclaes commited on
Commit
5ed6ee0
1 Parent(s): 262aca1

return confidence score

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -10,23 +10,25 @@ APP_NAME = "Mona Lisa Detection"
10
 
11
  logger.debug("loading processor and model.")
12
  processor = AutoFeatureExtractor.from_pretrained(
13
- "drift-ai/autotrain-mona-lisa-detection-38345101350",
14
- use_auth_token=True
15
  )
16
  model = AutoModelForImageClassification.from_pretrained(
17
- "drift-ai/autotrain-mona-lisa-detection-38345101350",
18
- use_auth_token=True
19
  )
20
  logger.debug("loading processor and model succeeded.")
21
 
22
 
23
  def process_image(image, model=model, processor=processor):
24
  logger.info("Making a prediction ...")
 
25
  inputs = processor(images=image, return_tensors="pt")
26
  outputs = model(**inputs)
27
  logits = outputs.logits
28
  predicted_class_idx = logits.argmax(-1).item()
29
- result = model.config.id2label[predicted_class_idx]
 
 
 
30
  print("Predicted class:", result)
31
  logger.info("Prediction finished.")
32
  return result
 
10
 
11
  logger.debug("loading processor and model.")
12
  processor = AutoFeatureExtractor.from_pretrained(
13
+ "drift-ai/autotrain-mona-lisa-detection-38345101350", use_auth_token=True
 
14
  )
15
  model = AutoModelForImageClassification.from_pretrained(
16
+ "drift-ai/autotrain-mona-lisa-detection-38345101350", use_auth_token=True
 
17
  )
18
  logger.debug("loading processor and model succeeded.")
19
 
20
 
21
  def process_image(image, model=model, processor=processor):
22
  logger.info("Making a prediction ...")
23
+
24
  inputs = processor(images=image, return_tensors="pt")
25
  outputs = model(**inputs)
26
  logits = outputs.logits
27
  predicted_class_idx = logits.argmax(-1).item()
28
+
29
+ label = {1: "Not Mona Lisa", 0: "Mona Lisa"}
30
+ predictions = logits.softmax(dim=-1).tolist()
31
+ result = {label[predicted_class_idx]: predictions[0][predicted_class_idx]}
32
  print("Predicted class:", result)
33
  logger.info("Prediction finished.")
34
  return result