File size: 4,879 Bytes
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f7eaee
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f7eaee
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e694a8
 
 
 
 
 
 
 
7f7eaee
 
 
3370ff8
 
 
 
 
 
 
 
 
7f7eaee
3370ff8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import pathlib
from constants import MODELS_REPO, MODELS_NAMES

import gradio as gr
import torch


from transformers import (AutoFeatureExtractor, DetrForObjectDetection,
                          YolosForObjectDetection)
from visualization import visualize_attention_map, visualize_prediction
from style import css, description, title


def make_prediction(img, feature_extractor, model):
    """Run one detection forward pass and collect the attention maps.

    Args:
        img: Input image. Only its ``size`` attribute is read here; the
            pair is reversed before being handed to post-processing
            (presumably PIL's (width, height) -> (height, width)).
        feature_extractor: Callable that preprocesses ``img`` into model
            inputs and also provides ``post_process(outputs, sizes)``.
        model: Callable detection model, invoked with the extractor's
            output unpacked as keyword arguments.

    Returns:
        A 4-tuple ``(detections, decoder_attn, encoder_attn, cross_attn)``
        where ``detections`` is the first element of the post-processed
        outputs. When the model output has a single ``"attentions"``
        entry, that same value fills all three attention slots.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        outputs = model(**inputs)
    # Post-processing takes one reversed-size row per image in the batch.
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    # if model type is YOLOS, then return "attentions"
    if "attentions" in outputs.keys():
        attentions = outputs["attentions"]
        return (
            processed_outputs[0],
            attentions,
            attentions,
            attentions,
        )
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
        outputs["cross_attentions"],
    )


def detect_objects(model_name, image_input, threshold):
    """Detect objects in an image and render detections and attention maps.

    Args:
        model_name: Key into ``MODELS_REPO``. It must contain ``"DETR"`` or
            ``"YOLOS"``, which selects the model class to load.
        image_input: Image to analyse; passed unchanged to the prediction
            and visualization helpers.
        threshold: Forwarded to ``visualize_prediction`` together with the
            detections.

    Returns:
        A 5-tuple ``(detections_img, decoder_attention_img,
        encoder_attention_img, cross_attention_img, model_details)``.

    Raises:
        ValueError: If ``model_name`` names neither a DETR nor a YOLOS model.
        KeyError: If ``model_name`` is missing from ``MODELS_REPO``.
    """
    # Resolve the architecture first so an unsupported name fails fast,
    # before any checkpoint is loaded. Every accepted branch must set
    # model_details, since it is part of the return value.
    if "DETR" in model_name:
        model_cls = DetrForObjectDetection
        model_details = "DETR details"
    elif "YOLOS" in model_name:
        model_cls = YolosForObjectDetection
        model_details = "YOLOS details"
    else:
        raise ValueError(
            f"Unsupported model {model_name!r}: expected a DETR or YOLOS model"
        )

    repo_id = MODELS_REPO[model_name]
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    model = model_cls.from_pretrained(repo_id)

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
        cross_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    viz_img = visualize_prediction(
        image_input, processed_outputs, threshold, model.config.id2label
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    cross_attention_map_img = visualize_attention_map(image_input, cross_attention_map)

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        cross_attention_map_img,
        model_details,
    )


def set_example_image(example: list) -> dict:
    """Build the image-component update for a clicked example.

    Args:
        example: The selected sample row; its first element is used as the
            new image value.

    Returns:
        The result of ``gr.Image.update`` with that value.
    """
    selected_value = example[0]
    image_update = gr.Image.update(value=selected_value)
    return image_update


# Build the Gradio UI: a header, then three tabs (upload + detections,
# attention maps, model details), followed by event wiring and launch.
with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        # Tab 1: image input, model/threshold controls, examples and the result.
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    # type="pil" -- presumably delivers a PIL image to the handler.
                    img_input = gr.Image(type="pil")
                with gr.Column():
                    # Model selector; defaults to the first entry of MODELS_NAMES.
                    options = gr.Dropdown(
                        value=MODELS_NAMES[0],
                        choices=MODELS_NAMES,
                        label="Select an object detection model",
                        show_label=True,
                    )
                    slider_input = gr.Slider(
                        minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                    )
                    detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                # One clickable sample per PNG found recursively under
                # cd45rb_test_imgs (relative to the working directory),
                # ordered by sorted path.
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[
                        [path.as_posix()]
                        for path in sorted(
                            pathlib.Path("cd45rb_test_imgs").rglob("*.png")
                        )
                    ],
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    # NOTE(review): shape=(850, 850) presumably fixes the
                    # displayed image size -- confirm against the Gradio version.
                    img_output_from_upload = gr.Image(shape=(850, 850))
        # Tab 2: one labelled image per attention type.
        with gr.TabItem("Attention maps visualization"):
            with gr.Row():
                gr.Markdown("""Encoder attentions""")
                encoder_att_map_output = gr.Image(shape=(850, 850))
            with gr.Row():
                gr.Markdown("""Decoder attentions""")
                decoder_att_map_output = gr.Image(shape=(850, 850))
            with gr.Row():
                gr.Markdown("""Cross attentions""")
                cross_att_map_output = gr.Image(shape=(850, 850))
        # Tab 3: Markdown placeholder filled with the text returned by detect_objects.
        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)

    # The outputs list must stay in the same order as the tuple returned by
    # detect_objects: detections, decoder, encoder, cross, model details.
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, slider_input],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    # Clicking a sample copies its first element into the input image.
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input]
    )

    # Launched while still inside the Blocks context.
    # NOTE(review): enable_queue presumably turns on request queuing -- confirm
    # it is still accepted by the installed Gradio version.
    app.launch(enable_queue=True)