polejowska commited on
Commit
6978df0
1 Parent(s): 079ad0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -31
app.py CHANGED
@@ -4,7 +4,6 @@ from constants import MODELS_REPO, MODELS_NAMES
4
  import gradio as gr
5
  import torch
6
 
7
-
8
  from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
9
  from visualization import visualize_attention_map, visualize_prediction
10
  from style import css, description, title
@@ -12,6 +11,7 @@ from style import css, description, title
12
  from PIL import Image
13
 
14
 
 
15
  def make_prediction(img, feature_extractor, model):
16
  inputs = feature_extractor(img, return_tensors="pt")
17
  outputs = model(**inputs)
@@ -29,8 +29,6 @@ def make_prediction(img, feature_extractor, model):
29
  def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
30
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
31
 
32
- print("Set threshold to: ", threshold)
33
-
34
  if "DETR" in model_name:
35
  model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
36
  model_details = "DETR details"
@@ -80,41 +78,43 @@ with gr.Blocks(css=css) as app:
80
  with gr.TabItem("Image upload and detections visualization"):
81
  with gr.Row():
82
  with gr.Column():
83
- img_input = gr.Image(type="pil")
 
84
  img_input_mask = gr.Image(type="pil", visible=False)
 
 
 
 
 
 
 
 
 
 
 
85
  with gr.Column():
86
- example_images = gr.Dataset(
87
- components=[img_input, img_input_mask],
88
- samples=[
89
- [path.as_posix(), path.as_posix().replace("_HE", "_mask")]
90
- for path in sorted(
91
- pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
92
- )
93
- ],
94
- samples_per_page=2,
95
- )
96
- with gr.Row():
97
- with gr.Column():
98
- options = gr.Dropdown(
99
- value=MODELS_NAMES[0],
100
- choices=MODELS_NAMES,
101
- label="Select an object detection model",
102
- show_label=True,
103
- )
104
- slider_input = gr.Slider(
105
- minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
106
- )
107
- with gr.Column():
108
- display_mask = gr.Checkbox(
109
- label="Display masks", default=False
110
- )
111
- detect_button = gr.Button("Detect leukocytes")
112
  with gr.Row():
113
  with gr.Column():
114
  gr.Markdown(
115
  """The selected image with detected bounding boxes by the model"""
116
  )
117
- img_output_from_upload = gr.Image(shape=(850, 850))
118
  with gr.TabItem("Attentions visualization"):
119
  gr.Markdown("""Encoder attentions""")
120
  with gr.Row():
 
4
  import gradio as gr
5
  import torch
6
 
 
7
  from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
8
  from visualization import visualize_attention_map, visualize_prediction
9
  from style import css, description, title
 
11
  from PIL import Image
12
 
13
 
14
+
15
  def make_prediction(img, feature_extractor, model):
16
  inputs = feature_extractor(img, return_tensors="pt")
17
  outputs = model(**inputs)
 
29
  def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
30
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
31
 
 
 
32
  if "DETR" in model_name:
33
  model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
34
  model_details = "DETR details"
 
78
  with gr.TabItem("Image upload and detections visualization"):
79
  with gr.Row():
80
  with gr.Column():
81
+ with gr.Row():
82
+ img_input = gr.Image(type="pil")
83
  img_input_mask = gr.Image(type="pil", visible=False)
84
+ with gr.Row():
85
+ example_images = gr.Dataset(
86
+ components=[img_input, img_input_mask],
87
+ samples=[
88
+ [path.as_posix(), path.as_posix().replace("_HE", "_mask")]
89
+ for path in sorted(
90
+ pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
91
+ )
92
+ ],
93
+ samples_per_page=2,
94
+ )
95
  with gr.Column():
96
+ with gr.Row():
97
+ options = gr.Dropdown(
98
+ value=MODELS_NAMES[0],
99
+ choices=MODELS_NAMES,
100
+ label="Select an object detection model",
101
+ show_label=True,
102
+ )
103
+ slider_input = gr.Slider(
104
+ minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
105
+ )
106
+ with gr.Row():
107
+ display_mask = gr.Checkbox(
108
+ label="Display masks", default=False
109
+ )
110
+ with gr.Row():
111
+ detect_button = gr.Button("Detect leukocytes")
 
 
 
 
 
 
 
 
 
 
112
  with gr.Row():
113
  with gr.Column():
114
  gr.Markdown(
115
  """The selected image with detected bounding boxes by the model"""
116
  )
117
+ img_output_from_upload = gr.Image(shape=(800, 800))
118
  with gr.TabItem("Attentions visualization"):
119
  gr.Markdown("""Encoder attentions""")
120
  with gr.Row():