File size: 3,351 Bytes
a726e1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110

import torch
from transformers import pipeline

from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from random import choice
import io

# Detection pipelines, one per DETR backbone. Both are built here at module
# level, so they are created once on import and reused by every call to
# infer(), which picks one of them based on the selected model name.
# NOTE(review): no task is passed to pipeline(); presumably it is resolved
# from the model id -- confirm against the transformers documentation.
detector50 = pipeline(model="facebook/detr-resnet-50")

detector101 = pipeline(model="facebook/detr-resnet-101")


import gradio as gr

# Candidate outline colors; get_figure() picks one at random per detection.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties handed to Axes.text() via ``fontdict`` for the box labels.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)


def get_figure(in_pil_img, in_results):
    """Draw detection results on top of an image and return the figure.

    Args:
        in_pil_img: Image shown as the background (passed to ``imshow``).
        in_results: Iterable of prediction dicts. Each one is read through
            ``prediction['box']`` (keys ``xmin``, ``ymin``, ``xmax``,
            ``ymax``), ``prediction['label']`` and ``prediction['score']``.

    Returns:
        The matplotlib figure (16x10 inches, axes hidden) containing the
        image with one outlined rectangle and one text label per prediction.
        The figure is created through pyplot, so it stays registered there
        until it is closed, e.g. with ``plt.close(fig)``.
    """
    # Hold explicit Figure/Axes handles instead of going through pyplot's
    # implicit "current figure" (gca/gcf), so the object returned is always
    # the one created by this call.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(in_pil_img)

    for prediction in in_results:
        box = prediction['box']
        x, y = box['xmin'], box['ymin']
        w, h = box['xmax'] - x, box['ymax'] - y

        # Each box gets a randomly chosen outline color from the palette.
        ax.add_patch(
            patches.Rectangle(
                (x, y), w, h, fill=False, color=choice(COLORS), linewidth=3
            )
        )
        # Label anchored at the (xmin, ymin) corner: "<label>: <score in %>".
        ax.text(
            x,
            y,
            f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
            fontdict=fdic,
        )

    ax.axis("off")

    return fig


def infer(model, in_pil_img):
    """Run object detection on an image and return the annotated picture.

    Args:
        model: Model name selected in the UI. ``"detr-resnet-101"`` selects
            ``detector101``; any other value falls back to ``detector50``.
        in_pil_img: Input image, or ``None`` when no image was provided.

    Returns:
        A PIL image of the rendered figure (input image plus boxes and
        labels), or ``None`` when ``in_pil_img`` is ``None``.
    """
    # Nothing to detect on: return early instead of letting the pipeline
    # fail on a missing image.
    if in_pil_img is None:
        return None

    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)

    figure = get_figure(in_pil_img, results)

    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure it created until it is closed; without
        # this each request would leave one more figure in memory.
        plt.close(figure)
    buf.seek(0)

    return Image.open(buf)


# UI layout: model selector, input/output image pair, example thumbnails,
# an "Infer" button wired to infer(), and a reference link.
with gr.Blocks(title="DETR Object Detection - ClassCat",
                    css=".gradio-container {background:lightyellow;}"
               ) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")

    model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
    gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")

    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")

    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")

    send_btn = gr.Button("Infer")
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # The list is emitted as a single HTML fragment: each gr.HTML call is a
    # separate component, so opening <ul> in one and closing it in another
    # yields broken markup. The <li> is also explicitly closed here.
    gr.HTML("""<ul><li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a></li></ul>""")


# NOTE(review): request queueing is left disabled, as before; consider
# enabling demo.queue() if long-running inference hits request timeouts.
#demo.queue()
demo.launch(debug=True)


### EOF ###