Spaces:
Runtime error
Runtime error
polejowska
commited on
Commit
•
6978df0
1
Parent(s):
079ad0e
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,6 @@ from constants import MODELS_REPO, MODELS_NAMES
|
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
|
7 |
-
|
8 |
from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
|
9 |
from visualization import visualize_attention_map, visualize_prediction
|
10 |
from style import css, description, title
|
@@ -12,6 +11,7 @@ from style import css, description, title
|
|
12 |
from PIL import Image
|
13 |
|
14 |
|
|
|
15 |
def make_prediction(img, feature_extractor, model):
|
16 |
inputs = feature_extractor(img, return_tensors="pt")
|
17 |
outputs = model(**inputs)
|
@@ -29,8 +29,6 @@ def make_prediction(img, feature_extractor, model):
|
|
29 |
def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
|
30 |
feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
|
31 |
|
32 |
-
print("Set threshold to: ", threshold)
|
33 |
-
|
34 |
if "DETR" in model_name:
|
35 |
model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
|
36 |
model_details = "DETR details"
|
@@ -80,41 +78,43 @@ with gr.Blocks(css=css) as app:
|
|
80 |
with gr.TabItem("Image upload and detections visualization"):
|
81 |
with gr.Row():
|
82 |
with gr.Column():
|
83 |
-
|
|
|
84 |
img_input_mask = gr.Image(type="pil", visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
with gr.Column():
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
show_label=True,
|
103 |
-
)
|
104 |
-
slider_input = gr.Slider(
|
105 |
-
minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
|
106 |
-
)
|
107 |
-
with gr.Column():
|
108 |
-
display_mask = gr.Checkbox(
|
109 |
-
label="Display masks", default=False
|
110 |
-
)
|
111 |
-
detect_button = gr.Button("Detect leukocytes")
|
112 |
with gr.Row():
|
113 |
with gr.Column():
|
114 |
gr.Markdown(
|
115 |
"""The selected image with detected bounding boxes by the model"""
|
116 |
)
|
117 |
-
img_output_from_upload = gr.Image(shape=(
|
118 |
with gr.TabItem("Attentions visualization"):
|
119 |
gr.Markdown("""Encoder attentions""")
|
120 |
with gr.Row():
|
|
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
|
|
|
7 |
from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
|
8 |
from visualization import visualize_attention_map, visualize_prediction
|
9 |
from style import css, description, title
|
|
|
11 |
from PIL import Image
|
12 |
|
13 |
|
14 |
+
|
15 |
def make_prediction(img, feature_extractor, model):
|
16 |
inputs = feature_extractor(img, return_tensors="pt")
|
17 |
outputs = model(**inputs)
|
|
|
29 |
def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
|
30 |
feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
|
31 |
|
|
|
|
|
32 |
if "DETR" in model_name:
|
33 |
model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
|
34 |
model_details = "DETR details"
|
|
|
78 |
with gr.TabItem("Image upload and detections visualization"):
|
79 |
with gr.Row():
|
80 |
with gr.Column():
|
81 |
+
with gr.Row():
|
82 |
+
img_input = gr.Image(type="pil")
|
83 |
img_input_mask = gr.Image(type="pil", visible=False)
|
84 |
+
with gr.Row():
|
85 |
+
example_images = gr.Dataset(
|
86 |
+
components=[img_input, img_input_mask],
|
87 |
+
samples=[
|
88 |
+
[path.as_posix(), path.as_posix().replace("_HE", "_mask")]
|
89 |
+
for path in sorted(
|
90 |
+
pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
|
91 |
+
)
|
92 |
+
],
|
93 |
+
samples_per_page=2,
|
94 |
+
)
|
95 |
with gr.Column():
|
96 |
+
with gr.Row():
|
97 |
+
options = gr.Dropdown(
|
98 |
+
value=MODELS_NAMES[0],
|
99 |
+
choices=MODELS_NAMES,
|
100 |
+
label="Select an object detection model",
|
101 |
+
show_label=True,
|
102 |
+
)
|
103 |
+
slider_input = gr.Slider(
|
104 |
+
minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
|
105 |
+
)
|
106 |
+
with gr.Row():
|
107 |
+
display_mask = gr.Checkbox(
|
108 |
+
label="Display masks", default=False
|
109 |
+
)
|
110 |
+
with gr.Row():
|
111 |
+
detect_button = gr.Button("Detect leukocytes")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
with gr.Row():
|
113 |
with gr.Column():
|
114 |
gr.Markdown(
|
115 |
"""The selected image with detected bounding boxes by the model"""
|
116 |
)
|
117 |
+
img_output_from_upload = gr.Image(shape=(800, 800))
|
118 |
with gr.TabItem("Attentions visualization"):
|
119 |
gr.Markdown("""Encoder attentions""")
|
120 |
with gr.Row():
|