Spaces:
Runtime error
Runtime error
polejowska
commited on
Commit
•
8a67e15
1
Parent(s):
9abbc70
Update app.py
Browse files
app.py
CHANGED
@@ -26,9 +26,11 @@ def make_prediction(img, feature_extractor, model):
|
|
26 |
)
|
27 |
|
28 |
|
29 |
-
def detect_objects(model_name, image_input,
|
30 |
feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
|
31 |
|
|
|
|
|
32 |
if "DETR" in model_name:
|
33 |
model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
|
34 |
model_details = "DETR details"
|
@@ -40,15 +42,16 @@ def detect_objects(model_name, image_input, image_path, threshold, display_mask=
|
|
40 |
cross_attention_map,
|
41 |
) = make_prediction(image_input, feature_extractor, model)
|
42 |
|
43 |
-
mask_pil_image = None
|
44 |
if display_mask:
|
45 |
-
|
46 |
-
image_path.stem.replace("_HE", "_mask") + image_path.suffix
|
47 |
-
)
|
48 |
-
mask_pil_image = Image.open(mask_path)
|
49 |
|
50 |
viz_img = visualize_prediction(
|
51 |
-
image_input,
|
|
|
|
|
|
|
|
|
|
|
52 |
)
|
53 |
decoder_attention_map_img = visualize_attention_map(
|
54 |
image_input, decoder_attention_map
|
@@ -130,7 +133,7 @@ with gr.Blocks(css=css) as app:
|
|
130 |
|
131 |
detect_button.click(
|
132 |
detect_objects,
|
133 |
-
inputs=[options, img_input, slider_input, display_mask],
|
134 |
outputs=[
|
135 |
img_output_from_upload,
|
136 |
decoder_att_map_output,
|
|
|
26 |
)
|
27 |
|
28 |
|
29 |
+
def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
|
30 |
feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
|
31 |
|
32 |
+
print("Set threshold to: ", threshold)
|
33 |
+
|
34 |
if "DETR" in model_name:
|
35 |
model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
|
36 |
model_details = "DETR details"
|
|
|
42 |
cross_attention_map,
|
43 |
) = make_prediction(image_input, feature_extractor, model)
|
44 |
|
|
|
45 |
if display_mask:
|
46 |
+
mask_pil_image = img_input_mask
|
|
|
|
|
|
|
47 |
|
48 |
viz_img = visualize_prediction(
|
49 |
+
image_input,
|
50 |
+
processed_outputs,
|
51 |
+
threshold,
|
52 |
+
model.config.id2label,
|
53 |
+
display_mask,
|
54 |
+
mask_pil_image
|
55 |
)
|
56 |
decoder_attention_map_img = visualize_attention_map(
|
57 |
image_input, decoder_attention_map
|
|
|
133 |
|
134 |
detect_button.click(
|
135 |
detect_objects,
|
136 |
+
inputs=[options, img_input, img_input_mask, slider_input, display_mask],
|
137 |
outputs=[
|
138 |
img_output_from_upload,
|
139 |
decoder_att_map_output,
|