File size: 5,128 Bytes
3370ff8
 
 
 
 
 
 
7cc09fd
3370ff8
 
 
074be6a
 
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5a0872
3370ff8
 
 
 
7f7eaee
3370ff8
 
 
 
 
 
 
 
074be6a
 
 
 
 
 
 
3370ff8
074be6a
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
7f7eaee
3370ff8
 
 
f5a0872
 
 
3370ff8
 
 
 
 
1882a3e
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
074be6a
 
 
 
 
 
 
 
 
3370ff8
 
 
 
 
 
 
074be6a
3370ff8
 
 
 
 
 
 
 
 
074be6a
340429d
3370ff8
1e694a8
340429d
1e694a8
 
340429d
1e694a8
 
7f7eaee
 
 
3370ff8
 
 
f5a0872
3370ff8
 
 
 
 
7f7eaee
3370ff8
 
 
 
f5a0872
3370ff8
 
 
1882a3e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import pathlib
from constants import MODELS_REPO, MODELS_NAMES

import gradio as gr
import torch


from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
from visualization import visualize_attention_map, visualize_prediction
from style import css, description, title

from PIL import Image


def make_prediction(img, feature_extractor, model):
    """Run a detection model on one image and collect its predictions and attentions.

    Args:
        img: Input image exposing a PIL-style ``size`` attribute (width, height).
        feature_extractor: Callable preprocessor that also provides ``post_process``.
        model: Detection model; its output must contain the
            ``decoder_attentions``, ``encoder_attentions`` and ``cross_attentions``
            entries (presumably requires attentions to be enabled in the model
            config — TODO confirm).

    Returns:
        Tuple of (post-processed detections for the image, decoder attentions,
        encoder attentions, cross attentions).
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); the post-processor is given (height, width).
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, target_sizes)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
        outputs["cross_attentions"],
    )


def detect_objects(model_name, image_input, image_path, threshold, display_mask=False):
    """Run the selected detector on an image and build every visualization.

    Args:
        model_name: Display name of the model, used as a key of ``MODELS_REPO``.
            Only names containing "DETR" are currently supported.
        image_input: The input image (from the ``gr.Image(type="pil")`` component).
        image_path: Path (``str`` or ``pathlib.Path``) of the selected example
            image, or ``None`` if none was selected. Only used to locate the
            matching ``*_mask`` file when ``display_mask`` is set.
        threshold: Prediction threshold forwarded to ``visualize_prediction``.
        display_mask: Whether to overlay the mask image that sits next to
            ``image_path``.

    Returns:
        Tuple of (prediction visualization, decoder attention image,
        encoder attention image, cross attention image, model details text).

    Raises:
        ValueError: If ``model_name`` is not a supported model family.
    """
    # Fail fast: previously an unsupported name left ``model`` unbound and the
    # call died with UnboundLocalError after downloading the feature extractor.
    if "DETR" not in model_name:
        raise ValueError(f"Unsupported model: {model_name!r}")

    repo_id = MODELS_REPO[model_name]
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    model = DetrForObjectDetection.from_pretrained(repo_id)
    model_details = "DETR details"

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
        cross_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    mask_pil_image = None
    # The path comes from UI state as a plain string (or None if no example
    # was picked), so normalize it before doing path arithmetic.
    if display_mask and image_path is not None:
        image_path = pathlib.Path(image_path)
        mask_path = image_path.parent / (
            image_path.stem.replace("_HE", "_mask") + image_path.suffix
        )
        # copy() loads the pixels so the file handle can be closed right away.
        with Image.open(mask_path) as mask_file:
            mask_pil_image = mask_file.copy()

    # Only ask for the overlay when a mask was actually loaded; requesting it
    # without an example path would otherwise pass a None mask downstream.
    show_mask = mask_pil_image is not None
    viz_img = visualize_prediction(
        image_input, processed_outputs, threshold, model.config.id2label, show_mask, mask_pil_image
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    cross_attention_map_img = visualize_attention_map(image_input, cross_attention_map)

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        cross_attention_map_img,
        model_details
    )


def set_example_image(example: list):
    """Show a clicked example in the image input and remember its path.

    Args:
        example: The clicked dataset sample; its first element is the image
            file path.

    Returns:
        A pair of (image component update pointing at the example, the
        example's path).
    """
    selected_path = example[0]
    image_update = gr.Image.update(value=selected_path)
    return image_update, selected_path


with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    gr.Markdown(description)
    # Path of the currently selected example image. This must be a gr.State
    # component (not a plain Python value) to be usable as an event input and
    # output; it holds None until an example is clicked.
    img_path = gr.State()

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                with gr.Column():
                    options = gr.Dropdown(
                        value=MODELS_NAMES[0],
                        choices=MODELS_NAMES,
                        label="Select an object detection model",
                        show_label=True,
                    )
                    with gr.Row():
                        with gr.Column():
                            slider_input = gr.Slider(
                                minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                            )
                        with gr.Column():
                            # Gradio 3 components take their initial state via `value`.
                            display_mask = gr.Checkbox(
                                label="Display masks", value=False
                            )
                    detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                # One sample per H&E test image found under cd45rb_test_imgs.
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[
                        [path.as_posix()]
                        for path in sorted(
                            pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                        )
                    ],
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    img_output_from_upload = gr.Image(shape=(850, 850))
        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Cross attentions""")
            with gr.Row():
                cross_att_map_output = gr.Image(shape=(850, 850))
        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)

    detect_button.click(
        detect_objects,
        inputs=[options, img_input, img_path, slider_input, display_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    # Clicking an example fills the image input and records its path so the
    # matching mask can be found later.
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input, img_path]
    )

# Launch only after the Blocks context has closed, so the layout is finalized.
app.launch(enable_queue=True)