File size: 5,133 Bytes
3370ff8
 
 
 
 
 
 
7cc09fd
3370ff8
 
 
074be6a
 
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a67e15
3370ff8
 
8a67e15
 
3370ff8
 
7f7eaee
3370ff8
 
 
 
 
 
 
 
074be6a
8a67e15
074be6a
3370ff8
8a67e15
 
 
 
 
 
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
7f7eaee
3370ff8
 
 
f5a0872
cb1656e
3370ff8
 
 
 
 
 
 
 
 
 
 
9abbc70
3370ff8
 
 
 
 
 
 
074be6a
 
 
 
 
 
 
 
 
3370ff8
 
 
9abbc70
3370ff8
9abbc70
3370ff8
074be6a
3370ff8
 
 
 
 
 
 
 
 
074be6a
340429d
3370ff8
1e694a8
340429d
1e694a8
 
340429d
1e694a8
 
7f7eaee
 
 
3370ff8
 
 
8a67e15
3370ff8
 
 
 
 
7f7eaee
3370ff8
 
 
 
cb1656e
3370ff8
 
9abbc70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import pathlib
from constants import MODELS_REPO, MODELS_NAMES

import gradio as gr
import torch


from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
from visualization import visualize_attention_map, visualize_prediction
from style import css, description, title

from PIL import Image


def make_prediction(img, feature_extractor, model):
    """Run ``model`` on ``img`` and return its detections plus attention maps.

    Args:
        img: PIL image to run detection on; ``img.size`` is read as
            ``(width, height)``.
        feature_extractor: Feature extractor used both to preprocess ``img``
            and to post-process the raw model outputs.
        model: Object-detection model exposing decoder/encoder/cross
            attentions when called with ``output_attentions=True``.

    Returns:
        A 4-tuple ``(detections, decoder_attentions, encoder_attentions,
        cross_attentions)`` where ``detections`` is the post-processed result
        for this single image.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: no gradients needed. Attention tensors are returned only
    # when explicitly requested; otherwise they are None and absent from the
    # output mapping, so the lookups below would raise KeyError.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Target sizes are expected as (height, width); PIL reports (width, height).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
        outputs["cross_attentions"],
    )


def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
    """Run object detection on an image and build every visualization.

    Args:
        model_name: Key into ``MODELS_REPO``; only names containing "DETR"
            are supported.
        image_input: PIL image to run detection on.
        threshold: Prediction threshold forwarded to ``visualize_prediction``.
        display_mask: Whether ``img_input_mask`` should be forwarded to
            ``visualize_prediction``.
        img_input_mask: Mask image, used only when ``display_mask`` is true.

    Returns:
        A 5-tuple ``(prediction_img, decoder_attention_img,
        encoder_attention_img, cross_attention_img, model_details)``.

    Raises:
        ValueError: If no image is given or ``model_name`` is not a DETR model.
    """
    if image_input is None:
        raise ValueError("Please select or upload an image first.")
    if "DETR" not in model_name:
        # Only DETR checkpoints are wired up below; anything else would have
        # no model to run.
        raise ValueError(f"Unsupported model: {model_name!r}")

    repo_id = MODELS_REPO[model_name]
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    model = DetrForObjectDetection.from_pretrained(repo_id)
    model_details = "DETR details"

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
        cross_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    # Bind the mask on every path: visualize_prediction always receives it,
    # and it must be None when masks are not requested.
    mask_pil_image = img_input_mask if display_mask else None

    viz_img = visualize_prediction(
        image_input,
        processed_outputs,
        threshold,
        model.config.id2label,
        display_mask,
        mask_pil_image,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    cross_attention_map_img = visualize_attention_map(image_input, cross_attention_map)

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        cross_attention_map_img,
        model_details,
    )


def set_example_image(example: list):
    """Show the clicked example's image in the main image input.

    Args:
        example: One sample row of the examples dataset; its first entry is
            the image to display.

    Returns:
        A ``gr.Image`` update whose value is that first entry.
    """
    image_value = example[0]
    image_update = gr.Image.update(value=image_value)
    return image_update


with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                    # Hidden: filled only from the clicked example's mask path.
                    img_input_mask = gr.Image(type="pil", visible=False)
                with gr.Column():
                    options = gr.Dropdown(
                        value=MODELS_NAMES[0],
                        choices=MODELS_NAMES,
                        label="Select an object detection model",
                        show_label=True,
                    )
                    with gr.Row():
                        with gr.Column():
                            slider_input = gr.Slider(
                                minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                            )
                        with gr.Column():
                            display_mask = gr.Checkbox(
                                label="Display masks", value=False
                            )
                    detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                # Each sample pairs an H&E image with its mask (same path with
                # "_HE" replaced by "_mask").
                example_images = gr.Dataset(
                    components=[img_input, img_input_mask],
                    samples=[
                        [path.as_posix(), path.as_posix().replace("_HE", "_mask")]
                        for path in sorted(
                            pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                        )
                    ],
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    img_output_from_upload = gr.Image(shape=(850, 850))
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Cross attentions""")
            with gr.Row():
                cross_att_map_output = gr.Image(shape=(850, 850))
        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)

    # Gradio passes `inputs` positionally, so this order must match
    # detect_objects(model_name, image_input, threshold, display_mask, img_input_mask).
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, slider_input, display_mask, img_input_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input]
    )
    # Also load the example's mask (second sample entry) into the hidden mask
    # input so "Display masks" has something to show.
    example_images.click(
        fn=lambda example: gr.Image.update(value=example[1]),
        inputs=[example_images],
        outputs=[img_input_mask],
    )

# Launch after the Blocks context has closed so the layout is finalized.
app.launch(enable_queue=True)