File size: 5,433 Bytes
3370ff8
 
 
 
 
 
 
7cc09fd
3370ff8
 
 
074be6a
 
3370ff8
 
 
 
 
d24bef6
3370ff8
 
 
 
 
 
 
 
 
8a67e15
3370ff8
 
8a67e15
 
3370ff8
 
7f7eaee
3370ff8
 
 
 
 
 
 
 
 
7923a1c
 
 
 
 
2a99234
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
7f7eaee
3370ff8
 
 
f5a0872
12dc72b
 
4b6f2a1
3370ff8
 
 
 
 
 
 
 
 
 
b0c1c95
3370ff8
 
 
 
 
 
 
074be6a
 
 
 
 
 
 
 
 
3370ff8
 
 
9abbc70
3370ff8
9abbc70
3370ff8
074be6a
3370ff8
 
d24bef6
3370ff8
 
 
 
 
 
 
074be6a
340429d
3370ff8
1e694a8
340429d
1e694a8
 
340429d
1e694a8
 
7f7eaee
 
 
8c96c1c
 
 
3370ff8
 
 
0592643
3370ff8
 
 
 
 
7f7eaee
3370ff8
 
 
 
12dc72b
 
3370ff8
 
9abbc70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import pathlib
from constants import MODELS_REPO, MODELS_NAMES

import gradio as gr
import torch


from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
from visualization import visualize_attention_map, visualize_prediction
from style import css, description, title

from PIL import Image


def make_prediction(img, feature_extractor, model):
    """Run the detector on one image and return detections plus attention maps.

    Args:
        img: Input image exposing a ``size`` attribute; its reversed value is
            passed to ``post_process`` as the target size.
        feature_extractor: Callable preprocessor (invoked with
            ``return_tensors="pt"``) that also provides
            ``post_process(outputs, target_sizes)``.
        model: Detection model, called with the preprocessed inputs as
            keyword arguments.

    Returns:
        A 4-tuple ``(detections, decoder_attentions, encoder_attentions,
        cross_attentions)``, where ``detections`` is the first entry of the
        post-processed batch.
    """
    inputs = feature_extractor(img, return_tensors="pt")

    # Inference only: disable autograd so no graph is kept alive per request,
    # and request the attention tensors explicitly instead of relying on the
    # checkpoint's configuration to expose them.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # ``img.size`` is reversed before being used as the target size
    # (for PIL images this turns (width, height) into (height, width)).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)

    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
        outputs["cross_attentions"],
    )


def detect_objects(model_name, image_input, threshold, display_mask=False, img_input_mask=None):
    """Detect objects in an image and render the prediction and attention maps.

    Args:
        model_name: Key into ``MODELS_REPO``; must contain ``"DETR"``.
        image_input: Image to analyse (the UI supplies a PIL image).
        threshold: Prediction threshold forwarded to ``visualize_prediction``.
        display_mask: Forwarded to ``visualize_prediction`` as ``display_mask``.
        img_input_mask: Optional mask image forwarded to
            ``visualize_prediction`` as ``mask``.

    Returns:
        A 5-tuple ``(prediction_image, decoder_attention_image,
        encoder_attention_image, cross_attention_image, model_details)``.

    Raises:
        ValueError: If ``model_name`` is not a DETR model or no image was
            provided.
    """
    # Validate up front: previously a non-DETR name left ``model`` and
    # ``model_details`` unbound and failed later with an UnboundLocalError.
    if "DETR" not in model_name:
        raise ValueError(
            f"Unsupported model {model_name!r}: only DETR models are handled."
        )
    # Fail with a readable message instead of crashing inside preprocessing.
    if image_input is None:
        raise ValueError("No input image provided.")

    feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])

    print("Set threshold to: ", threshold)

    model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
    model_details = "DETR details"

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
        cross_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    cross_attention_map_img = visualize_attention_map(image_input, cross_attention_map)

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        cross_attention_map_img,
        model_details,
    )


def set_example_image(example: list):
    """Load a clicked dataset sample into the image and mask inputs.

    Args:
        example: Sample row; element 0 is the image value and element 1 is
            the mask value.

    Returns:
        A pair of ``gr.Image`` updates: the first carries the image value,
        the second the mask value.
    """
    image_value = example[0]
    print(f"Set example image to: {image_value}")
    mask_value = example[1]
    print(f"Set example image mask to: {mask_value}")

    image_update = gr.Image.update(value=image_value)
    mask_update = gr.Image.update(value=mask_value)
    return image_update, mask_update


# Build the Gradio UI: one tab for upload/detection, one for attention maps,
# and two informational tabs (model details, dataset details).
with gr.Blocks(css=css) as app:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                    # Hidden component: only holds the mask that belongs to
                    # the selected example image.
                    img_input_mask = gr.Image(type="pil", visible=False)
                with gr.Column():
                    options = gr.Dropdown(
                        value=MODELS_NAMES[0],
                        choices=MODELS_NAMES,
                        label="Select an object detection model",
                        show_label=True,
                    )
                    with gr.Row():
                        with gr.Column():
                            slider_input = gr.Slider(
                                minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                            )
                        with gr.Column():
                            # ``value`` (not the legacy ``default``) sets the
                            # initial state, matching the other components.
                            display_mask = gr.Checkbox(
                                label="Display masks", value=False
                            )
                    detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                # Each sample pairs an "*_HE.png" image with the path obtained
                # by replacing "_HE" with "_mask" in that same path.
                example_images = gr.Dataset(
                    components=[img_input, img_input_mask],
                    samples=[
                        [path.as_posix(), path.as_posix().replace("_HE", "_mask")]
                        for path in sorted(
                            pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                        )
                    ],
                    samples_per_page=2,
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    img_output_from_upload = gr.Image(shape=(850, 850))
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Cross attentions""")
            with gr.Row():
                cross_att_map_output = gr.Image(shape=(850, 850))
        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)
        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    # The order of ``outputs`` must match the tuple returned by
    # ``detect_objects``: prediction, decoder, encoder, cross, model details.
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, slider_input, display_mask, img_input_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    # Clicking a dataset sample fills the visible image and the hidden mask.
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input, img_input_mask],
        show_progress=True
    )

    # Start the app with request queuing enabled.
    app.launch(enable_queue=True)