polejowska commited on
Commit
8a67e15
1 Parent(s): 9abbc70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -26,9 +26,11 @@ def make_prediction(img, feature_extractor, model):
26
  )
27
 
28
 
29
- def detect_objects(model_name, image_input, image_path, threshold, display_mask=False):
30
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
31
 
 
 
32
  if "DETR" in model_name:
33
  model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
34
  model_details = "DETR details"
@@ -40,15 +42,16 @@ def detect_objects(model_name, image_input, image_path, threshold, display_mask=
40
  cross_attention_map,
41
  ) = make_prediction(image_input, feature_extractor, model)
42
 
43
- mask_pil_image = None
44
  if display_mask:
45
- mask_path = image_path.parent / (
46
- image_path.stem.replace("_HE", "_mask") + image_path.suffix
47
- )
48
- mask_pil_image = Image.open(mask_path)
49
 
50
  viz_img = visualize_prediction(
51
- image_input, processed_outputs, threshold, model.config.id2label, display_mask, mask_pil_image
 
 
 
 
 
52
  )
53
  decoder_attention_map_img = visualize_attention_map(
54
  image_input, decoder_attention_map
@@ -130,7 +133,7 @@ with gr.Blocks(css=css) as app:
130
 
131
  detect_button.click(
132
  detect_objects,
133
- inputs=[options, img_input, slider_input, display_mask],
134
  outputs=[
135
  img_output_from_upload,
136
  decoder_att_map_output,
 
26
  )
27
 
28
 
29
+ def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
30
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
31
 
32
+ print("Set threshold to: ", threshold)
33
+
34
  if "DETR" in model_name:
35
  model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
36
  model_details = "DETR details"
 
42
  cross_attention_map,
43
  ) = make_prediction(image_input, feature_extractor, model)
44
 
 
45
  if display_mask:
46
+ mask_pil_image = img_input_mask
 
 
 
47
 
48
  viz_img = visualize_prediction(
49
+ image_input,
50
+ processed_outputs,
51
+ threshold,
52
+ model.config.id2label,
53
+ display_mask,
54
+ mask_pil_image
55
  )
56
  decoder_attention_map_img = visualize_attention_map(
57
  image_input, decoder_attention_map
 
133
 
134
  detect_button.click(
135
  detect_objects,
136
+ inputs=[options, img_input, img_input_mask, slider_input, display_mask],
137
  outputs=[
138
  img_output_from_upload,
139
  decoder_att_map_output,