polejowska committed
Commit 074be6a
1 Parent(s): d52d15a

Update app.py

Files changed (1):
  1. app.py (+26, -16)
app.py CHANGED
@@ -9,6 +9,8 @@ from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
 from visualization import visualize_attention_map, visualize_prediction
 from style import css, description, title
 
+from PIL import Image
+
 
 def make_prediction(img, feature_extractor, model):
     inputs = feature_extractor(img, return_tensors="pt")
@@ -16,14 +18,6 @@ def make_prediction(img, feature_extractor, model):
     img_size = torch.tensor([tuple(reversed(img.size))])
     processed_outputs = feature_extractor.post_process(outputs, img_size)
     print(outputs.keys())
-    # if model type is YOLOS, then return "attentions"
-    if "attentions" in outputs.keys():
-        return (
-            processed_outputs[0],
-            outputs["attentions"],
-            outputs["attentions"],
-            outputs["attentions"],
-        )
     return (
         processed_outputs[0],
         outputs["decoder_attentions"],
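Dropping the YOLOS branch means make_prediction now assumes DETR-style outputs, which expose decoder_attentions, encoder_attentions, and cross_attentions only when the forward pass asks for them. A minimal sketch of that assumption against the public facebook/detr-resnet-50 checkpoint (a stand-in; the app resolves its own checkpoints through MODELS_REPO):

```python
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection

ckpt = "facebook/detr-resnet-50"  # stand-in checkpoint, not the app's own
extractor = AutoFeatureExtractor.from_pretrained(ckpt)
model = DetrForObjectDetection.from_pretrained(ckpt)

img = Image.new("RGB", (800, 600))  # placeholder image
inputs = extractor(img, return_tensors="pt")
with torch.no_grad():
    # Attention tensors only appear in the output when requested explicitly.
    outputs = model(**inputs, output_attentions=True)

print("decoder_attentions" in outputs.keys())  # True
print("cross_attentions" in outputs.keys())    # True
```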
@@ -32,7 +26,7 @@ def make_prediction(img, feature_extractor, model):
     )
 
 
-def detect_objects(model_name, image_input, threshold):
+def detect_objects(model_name, image_input, threshold, display_mask=False):
     feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
 
     if "DETR" in model_name:
@@ -46,8 +40,17 @@ def detect_objects(model_name, image_input, threshold):
         cross_attention_map,
     ) = make_prediction(image_input, feature_extractor, model)
 
+    mask_pil_image = None
+    if display_mask:
+        # derive the companion mask path from the input image path
+        image_path = pathlib.Path(image_input.name)
+        mask_path = image_path.parent / (
+            image_path.stem.replace("_HE", "_mask") + image_path.suffix
+        )
+        mask_pil_image = Image.open(mask_path)
+
     viz_img = visualize_prediction(
-        image_input, processed_outputs, threshold, model.config.id2label
+        image_input, processed_outputs, threshold, model.config.id2label, display_mask, mask_pil_image
     )
     decoder_attention_map_img = visualize_attention_map(
         image_input, decoder_attention_map
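The mask lookup relies on a file-naming convention: every *_HE.png image ships with a *_mask.png sibling in the same directory. A self-contained sketch of the derivation (the helper name is hypothetical):

```python
from pathlib import Path

def mask_path_for(image_path: str) -> Path:
    """Map foo_HE.png to foo_mask.png in the same directory (the app's convention)."""
    p = Path(image_path)
    return p.parent / (p.stem.replace("_HE", "_mask") + p.suffix)

assert mask_path_for("cd45rb_test_imgs/slide1_HE.png").name == "slide1_mask.png"
```

Note that image_input.name presumes the Gradio image component hands the callback a file-like object carrying a .name path; whether that holds depends on how img_input is configured elsewhere in app.py.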
@@ -86,9 +89,15 @@ with gr.Blocks(css=css) as app:
             label="Select an object detection model",
             show_label=True,
         )
-        slider_input = gr.Slider(
-            minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
-        )
+        with gr.Row():
+            with gr.Column():
+                slider_input = gr.Slider(
+                    minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
+                )
+            with gr.Column():
+                display_mask = gr.Checkbox(
+                    label="Display masks", default=False
+                )
         detect_button = gr.Button("Detect leukocytes")
         with gr.Row():
             example_images = gr.Dataset(
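One keyword worth double-checking in the hunk above: in the Gradio 3.x Blocks API the initial state of a Checkbox is set with value=, while default= is the 2.x-era name, so the added line may warn or error depending on the pinned Gradio version. A minimal sketch of the 3.x form:

```python
import gradio as gr

with gr.Blocks() as demo:
    # Gradio 3.x Blocks: initial state via `value=`, not the 2.x `default=`.
    display_mask = gr.Checkbox(label="Display masks", value=False)
```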
@@ -96,7 +105,7 @@ with gr.Blocks(css=css) as app:
             samples=[
                 [path.as_posix()]
                 for path in sorted(
-                    pathlib.Path("cd45rb_test_imgs").rglob("*.png")
+                    pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                 )
             ],
         )
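Tightening the glob matters because the masks live next to the images: *.png would pull the *_mask.png files into the example gallery, while *_HE.png keeps only the stained inputs. A sketch, assuming the naming convention above:

```python
import pathlib

# "*.png" would also match slide1_mask.png; the suffix filter keeps only
# the HE-stained images as gallery samples.
examples = sorted(pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png"))
```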
@@ -106,7 +115,7 @@ with gr.Blocks(css=css) as app:
                 """The selected image with detected bounding boxes by the model"""
             )
             img_output_from_upload = gr.Image(shape=(850, 850))
-        with gr.TabItem("Attention maps visualization"):
+        with gr.TabItem("Attentions visualization"):
             gr.Markdown("""Encoder attentions""")
             with gr.Row():
                 encoder_att_map_output = gr.Image(shape=(850, 850))
@@ -122,7 +131,7 @@ with gr.Blocks(css=css) as app:
 
     detect_button.click(
         detect_objects,
-        inputs=[options, img_input, slider_input],
+        inputs=[options, img_input, slider_input, display_mask],
        outputs=[
             img_output_from_upload,
             decoder_att_map_output,
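The .click() wiring hands components to the callback positionally, so appending display_mask to inputs is exactly what feeds the new fourth parameter of detect_objects. A stripped-down sketch of the pattern with hypothetical components:

```python
import gradio as gr

def detect(model_name, threshold, display_mask):
    # Stand-in for detect_objects: arguments arrive in `inputs` list order.
    return f"{model_name} @ {threshold:.2f}, masks={'on' if display_mask else 'off'}"

with gr.Blocks() as demo:
    options = gr.Radio(choices=["DETR", "YOLOS"], label="Model")
    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.7)
    display_mask = gr.Checkbox(label="Display masks", value=False)
    result = gr.Textbox()
    gr.Button("Detect").click(
        detect, inputs=[options, slider_input, display_mask], outputs=[result]
    )
```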
@@ -137,3 +146,4 @@ with gr.Blocks(css=css) as app:
     )
 
 app.launch(enable_queue=True)
+