Add app.py
- app.py +102 -0
- constants.py +33 -0
- examples/1.jpg +0 -0
- examples/2.jpg +0 -0
- examples/3.jpg +0 -0
- models.py +135 -0
- utils.py +250 -0
app.py
ADDED
@@ -0,0 +1,102 @@
import gradio as gr
import numpy as np
from ultralytics import YOLO
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download
import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

from utils import *
from models import YOLOStamp, Encoder

device = 'cuda' if torch.cuda.is_available() else 'cpu'


yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect')

yolo_stamp = YOLOStamp()
yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth')))
yolo_stamp = yolo_stamp.to(device)
yolo_stamp.eval()
transform = A.Compose([
    A.Normalize(),
    ToTensorV2(p=1.0),
])

vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'))
vits8 = vits8.to(device)
vits8.eval()

encoder = Encoder()
encoder.load_state_dict(torch.load(hf_hub_download('stamps-labs/vae-encoder', filename='encoder.pth')))
encoder = encoder.to(device)
encoder.eval()


def predict(image, det_choice, emb_choice):

    shape = torch.tensor(image.size)
    image = image.convert('RGB')

    if det_choice == 'yolov8':
        coef = torch.hstack((shape, shape)) / 640
        image = image.resize((640, 640))
        boxes = yolov8(image)[0].boxes.xyxy.cpu()
        image_with_boxes = visualize_bbox(image, boxes)

    elif det_choice == 'yolo-stamp':
        coef = torch.hstack((shape, shape)) / 448
        image = image.resize((448, 448))
        image_tensor = transform(image=np.array(image))['image']
        output = yolo_stamp(image_tensor.unsqueeze(0).to(device))

        boxes = output_tensor_to_boxes(output[0].detach().cpu())
        boxes = nonmax_suppression(boxes)
        boxes = xywh2xyxy(torch.tensor(boxes)[:, :4])
        image_with_boxes = visualize_bbox(image, boxes)
    else:
        return


    embeddings = []
    if emb_choice == 'vits8':
        for box in boxes:
            cropped_stamp = to_tensor(image.crop(box.tolist()))
            embeddings.append(vits8(cropped_stamp.unsqueeze(0).to(device))[0].detach().cpu())

    elif emb_choice == 'vae-encoder':
        for box in boxes:
            cropped_stamp = to_tensor(image.crop(box.tolist()).resize((118, 118)))
            embeddings.append(encoder(cropped_stamp.unsqueeze(0).to(device))[0][0].detach().cpu())

    embeddings = np.stack(embeddings)

    similarities = cosine_similarity(embeddings)

    boxes = boxes * coef
    df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])

    fig, ax = plt.subplots()
    im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
                       cmap="YlGn", cbarlabel="Embeddings similarities")
    texts = annotate_heatmap(im, valfmt="{x:.3f}")
    return image_with_boxes, df_boxes, embeddings, fig


examples = [['./examples/1.jpg', 'yolov8', 'vits8'], ['./examples/2.jpg', 'yolov8', 'vae-encoder'], ['./examples/3.jpg', 'yolo-stamp', 'vits8']]
inputs = [
    gr.Image(type="pil"),
    gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
    gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
]
outputs = [
    gr.Image(type="pil"),
    gr.DataFrame(type='pandas', label="Bounding boxes"),
    gr.DataFrame(type='numpy', label="Embeddings"),
    gr.Plot(type='numpy', label="Cosine Similarities")
]
app = gr.Interface(predict, inputs, outputs, examples=examples)
app.launch()
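A minimal sketch of calling predict() directly, outside the Gradio UI, using one of the example images added in this commit (assuming the model downloads above succeed):

from PIL import Image

img = Image.open('./examples/1.jpg')
image_with_boxes, df_boxes, embeddings, fig = predict(img, 'yolov8', 'vits8')
print(df_boxes)          # bounding boxes rescaled back to original-image coordinates
print(embeddings.shape)  # (number of detected stamps, embedding dimension)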
constants.py
ADDED
@@ -0,0 +1,33 @@
# shape of input image to YOLO
W, H = 448, 448
# grid size after last convolutional layer of YOLO
S = 7
# anchors of YOLO model
ANCHORS = [[1.5340836003942058, 1.258424277571925],
           [1.4957766780406023, 2.2319885681948217],
           [1.2508985343739407, 0.8233350471152914]]
# number of anchor boxes
BOX = len(ANCHORS)
# maximum number of stamps on an image
STAMP_NB_MAX = 10
# minimal confidence of the presence of a stamp in a grid cell
OUTPUT_THRESH = 0.7
# maximal IoU score to consider boxes different
IOU_THRESH = 0.3
# path to folder containing images
IMAGE_FOLDER = './data/images'
# path to .csv file containing annotations
ANNOTATIONS_PATH = './data/all_annotations.csv'
# standard deviation and mean of pixel values for normalization
STD = (0.229, 0.224, 0.225)
MEAN = (0.485, 0.456, 0.406)
# box color to show the bounding box on image
BOX_COLOR = (0, 0, 255)


# dimension of image embedding
Z_DIM = 128
# hidden dimensions for encoder model
ENC_HIDDEN_DIM = 16
# hidden dimensions for decoder model
DEC_HIDDEN_DIM = 64
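A small illustration of how the grid constants are used downstream (see output_tensor_to_boxes in utils.py); the numbers are simply the constants above worked out:

# Each of the S x S grid cells covers W/S x H/S pixels.
cell_w, cell_h = W / S, H / S   # 448 / 7 = 64 px per cell
# Anchor sizes are expressed in cell units, so the first anchor is roughly
# 1.53 * 64 ≈ 98 px wide and 1.26 * 64 ≈ 81 px tall after decoding.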
examples/1.jpg
ADDED
examples/2.jpg
ADDED
examples/3.jpg
ADDED
models.py
ADDED
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn

from constants import *

"""
Class for custom activation.
"""
class SymReLU(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input):
        return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str


"""
Class implementing the YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046.
"""
class YOLOStamp(nn.Module):
    def __init__(
            self,
            anchors=ANCHORS,
            in_channels=3,
    ):
        super().__init__()

        self.register_buffer('anchors', torch.tensor(anchors))

        self.act = SymReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm1 = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm2 = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm3 = nn.BatchNorm2d(num_features=16)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm4 = nn.BatchNorm2d(num_features=16)
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm5 = nn.BatchNorm2d(num_features=16)
        self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm6 = nn.BatchNorm2d(num_features=24)
        self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm7 = nn.BatchNorm2d(num_features=24)
        self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm8 = nn.BatchNorm2d(num_features=48)
        self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm9 = nn.BatchNorm2d(num_features=48)
        self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm10 = nn.BatchNorm2d(num_features=48)
        self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.norm12 = nn.BatchNorm2d(num_features=256)
        self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

    def forward(self, x, head=True):
        x = x.type(self.conv1.weight.dtype)
        x = self.act(self.pool(self.norm1(self.conv1(x))))
        x = self.act(self.pool(self.norm2(self.conv2(x))))
        x = self.act(self.pool(self.norm3(self.conv3(x))))
        x = self.act(self.pool(self.norm4(self.conv4(x))))
        x = self.act(self.pool(self.norm5(self.conv5(x))))
        x = self.act(self.norm6(self.conv6(x)))
        x = self.act(self.norm7(self.conv7(x)))
        x = self.act(self.pool(self.norm8(self.conv8(x))))
        x = self.act(self.norm9(self.conv9(x)))
        x = self.act(self.norm10(self.conv10(x)))
        x = self.act(self.norm11(self.conv11(x)))
        x = self.act(self.norm12(self.conv12(x)))
        x = self.conv13(x)
        nb, _, nh, nw = x.shape
        x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
        return x


class Encoder(torch.nn.Module):
    '''
    Encoder Class
    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the dimension of the output latent vector, a scalar
        hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        self.disc = torch.nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE,
        consisting of a convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                torch.nn.BatchNorm2d(output_channels),
                torch.nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given an image tensor,
        returns the mean and standard deviation of the latent distribution.
        Parameters:
            image: an image tensor of shape (batch, im_chan, height, width)
        '''
        disc_pred = self.disc(image)
        encoding = disc_pred.view(len(disc_pred), -1)
        # The stddev output is treated as the log of the variance of the normal
        # distribution by convention and for numerical stability
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
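A brief sketch of how the Encoder's two outputs are consumed: app.py keeps only the mean as the stamp embedding, while VAE-style training would sample the latent with the reparameterization trick (the sampling line below is illustrative and not part of this Space):

encoder = Encoder()
stamp = torch.randn(1, 3, 118, 118)             # a cropped stamp resized as in app.py
mean, std = encoder(stamp)                      # each of shape (1, Z_DIM)
z_sample = mean + std * torch.randn_like(std)   # reparameterized sample (training time)
embedding = mean                                # deterministic embedding used at inference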
utils.py
ADDED
@@ -0,0 +1,250 @@
from PIL import Image, ImageDraw
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from constants import *

def visualize_bbox(image: Image, prediction):
    img = image.copy()
    draw = ImageDraw.Draw(img)
    for i, box in enumerate(prediction):
        x1, y1, x2, y2 = box.cpu()
        text_w, text_h = draw.textsize(str(i + 1))
        label_y = y1 if y1 <= text_h else y1 - text_h
        draw.rectangle((x1, y1, x2, y2), outline='red')
        draw.rectangle((x1, label_y, x1+text_w, label_y+text_h), outline='red', fill='red')
        draw.text((x1, label_y), str(i + 1), fill='white')
    return img

def xywh2xyxy(x):
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0]
    y[..., 1] = x[..., 1]
    y[..., 2] = x[..., 0] + x[..., 2]
    y[..., 3] = x[..., 1] + x[..., 3]
    return y

def output_tensor_to_boxes(boxes_tensor):
    """
    Converts the YOLO output tensor to a list of boxes with probabilities.

    Arguments:
    boxes_tensor -- tensor of shape (S, S, BOX, 5)

    Returns:
    boxes -- list of shape (None, 5)

    Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold.
    For example, the actual output size would be (10, 5) if there are 10 boxes.
    """
    cell_w, cell_h = W/S, H/S
    boxes = []

    for i in range(S):
        for j in range(S):
            for b in range(BOX):
                anchor_wh = torch.tensor(ANCHORS[b])
                data = boxes_tensor[i,j,b]
                xy = torch.sigmoid(data[:2])
                wh = torch.exp(data[2:4])*anchor_wh
                obj_prob = torch.sigmoid(data[4])

                if obj_prob > OUTPUT_THRESH:
                    x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1]
                    x, y = x_center+j-w/2, y_center+i-h/2
                    x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
                    box = [x,y,w,h, obj_prob]
                    boxes.append(box)
    return boxes

def overlap(interval_1, interval_2):
    """
    Calculates the length of overlap between two intervals.

    Arguments:
    interval_1 -- list or tuple of shape (2,) containing endpoints of the first interval
    interval_2 -- list or tuple of shape (2,) containing endpoints of the second interval

    Returns:
    overlap -- length of overlap
    """
    x1, x2 = interval_1
    x3, x4 = interval_2
    if x3 < x1:
        if x4 < x1:
            return 0
        else:
            return min(x2,x4) - x1
    else:
        if x2 < x3:
            return 0
        else:
            return min(x2,x4) - x3


def compute_iou(box1, box2):
    """
    Compute IoU between box1 and box2.

    Arguments:
    box1 -- list of shape (5, ). Represents the first box
    box2 -- list of shape (5, ). Represents the second box
    Each box is [x, y, w, h, prob]

    Returns:
    iou -- intersection over union score between the two boxes
    """
    x1,y1,w1,h1 = box1[0], box1[1], box1[2], box1[3]
    x2,y2,w2,h2 = box2[0], box2[1], box2[2], box2[3]

    area1, area2 = w1*h1, w2*h2
    intersect_w = overlap((x1,x1+w1), (x2,x2+w2))
    intersect_h = overlap((y1,y1+h1), (y2,y2+h2))
    if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2:
        return 1.
    intersect_area = intersect_w*intersect_h
    iou = intersect_area/(area1 + area2 - intersect_area)
    return iou


def nonmax_suppression(boxes, iou_thresh = IOU_THRESH):
    """
    Removes overlapping bboxes.

    Arguments:
    boxes -- list of shape (None, 5)
    iou_thresh -- maximal value of iou when boxes are considered different
    Each box is [x, y, w, h, prob]

    Returns:
    boxes -- list of shape (None, 5) with overlapping boxes removed
    """
    boxes = sorted(boxes, key=lambda x: x[4], reverse=True)
    for i, current_box in enumerate(boxes):
        if current_box[4] <= 0:
            continue
        for j in range(i+1, len(boxes)):
            iou = compute_iou(current_box, boxes[j])
            if iou > iou_thresh:
                boxes[j][4] = 0
    boxes = [box for box in boxes if box[4] > 0]
    return boxes

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    row_labels
        A list or array of length M with the labels for the rows.
    col_labels
        A list or array of length N with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
    cbarlabel
        The label for the colorbar. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.
    """

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    ax.spines[:].set_visible(False)

    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar

def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    A function to annotate a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate. If None, the image's data is used. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        A pair of colors. The first is used for values below a threshold,
        the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default) uses the middle of the colormap as
        separation. Optional.
    **kwargs
        All other arguments are forwarded to each call to `text` used to create
        the text labels.
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts
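A quick usage sketch for the NMS helper with boxes in the [x, y, w, h, prob] format it expects; the values are made up for illustration:

candidates = [
    [10, 10, 50, 50, 0.9],    # high-confidence stamp
    [12, 12, 50, 50, 0.8],    # near-duplicate of the first box (IoU > IOU_THRESH)
    [200, 200, 40, 40, 0.7],  # a separate stamp elsewhere on the page
]
kept = nonmax_suppression(candidates)
# kept -> [[10, 10, 50, 50, 0.9], [200, 200, 40, 40, 0.7]]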