Spaces:
Runtime error
Runtime error
from PIL import Image, ImageDraw | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib | |
from constants import * | |
def visualize_bbox(image: Image.Image, prediction):
    """
    Draws numbered red bounding boxes on a copy of the image.
    Arguments:
    image -- PIL image to annotate; it is copied, the original is not modified
    prediction -- iterable of boxes; each box must provide `.cpu()` and unpack
                  into four corner coordinates (x1, y1, x2, y2)
    Returns:
    img -- annotated copy of the image; boxes are labelled 1, 2, 3, ... in
           iteration order
    """
    img = image.copy()
    # One drawing context is enough for the whole image.
    draw = ImageDraw.Draw(img)
    for i, box in enumerate(prediction):
        # Convert the 0-d tensor elements to plain floats before handing them to PIL.
        x1, y1, x2, y2 = (float(v) for v in box.cpu())
        label = str(i + 1)
        if hasattr(draw, "textbbox"):
            # `textsize` was removed in Pillow 10; `textbbox` is its replacement.
            # Anchoring at (0, 0) makes right/bottom the extent of the drawn
            # glyphs measured from the point where the text is placed.
            _, _, text_w, text_h = draw.textbbox((0, 0), label)
        else:
            # Older Pillow releases without `textbbox`.
            text_w, text_h = draw.textsize(label)
        # Put the label above the box unless that would run off the top edge.
        label_y = y1 if y1 <= text_h else y1 - text_h
        draw.rectangle((x1, y1, x2, y2), outline='red')
        draw.rectangle((x1, label_y, x1 + text_w, label_y + text_h), outline='red', fill='red')
        draw.text((x1, label_y), label, fill='white')
    return img
def xywh2xyxy(x):
    """
    Converts boxes from [x, y, w, h] to [x1, y1, x2, y2] along the last axis.
    Arguments:
    x -- torch.Tensor or numpy array whose last axis holds x, y, w, h
    Returns:
    converted -- copy of the input (same container type) in which entries 2 and 3
                 of the last axis hold x + w and y + h; the input is not modified
    """
    if isinstance(x, torch.Tensor):
        converted = x.clone()
    else:
        converted = np.copy(x)
    # Entries 0 and 1 are already correct in the copy; only the far corner changes.
    converted[..., 2] = x[..., 0] + x[..., 2]
    converted[..., 3] = x[..., 1] + x[..., 3]
    return converted
def output_tensor_to_boxes(boxes_tensor):
    """
    Decodes a raw YOLO output tensor into a list of boxes with probabilities.
    Arguments:
    boxes_tensor -- tensor indexed as [row, col, anchor], i.e. shape (S, S, BOX, 5);
                    the last axis holds raw x, y, w, h and objectness values
    Returns:
    selected -- list of [x, y, w, h, obj_prob] entries, one per prediction whose
                objectness exceeds OUTPUT_THRESH
    Note: the length of the list depends on the threshold, so it is not known in
    advance. x and y are the box's top-left corner; all four geometry values are
    scaled by the cell size (W/S, H/S). Every entry is a 0-d tensor.
    """
    cell_w = W / S
    cell_h = H / S
    selected = []
    for row in range(S):
        for col in range(S):
            for anchor_idx in range(BOX):
                anchor_wh = torch.tensor(ANCHORS[anchor_idx])
                raw = boxes_tensor[row, col, anchor_idx]
                # Centre offset within the cell and size relative to the anchor.
                centre = torch.sigmoid(raw[:2])
                size = torch.exp(raw[2:4]) * anchor_wh
                obj_prob = torch.sigmoid(raw[4])
                if not (obj_prob > OUTPUT_THRESH):
                    continue
                box_w, box_h = size[0], size[1]
                # Move from cell-relative centre to grid-relative top-left corner.
                left = centre[0] + col - box_w / 2
                top = centre[1] + row - box_h / 2
                selected.append([
                    left * cell_w,
                    top * cell_h,
                    box_w * cell_w,
                    box_h * cell_h,
                    obj_prob,
                ])
    return selected
def overlap(interval_1, interval_2):
    """
    Calculates the length of the overlap between two intervals.
    Arguments:
    interval_1 -- list or tuple of shape (2,) with the (start, end) of the first interval
    interval_2 -- list or tuple of shape (2,) with the (start, end) of the second interval
    Returns:
    overlap -- length of the shared part; 0 when the intervals are disjoint
    """
    start_1, end_1 = interval_1
    start_2, end_2 = interval_2
    # The overlap can only begin at the later of the two starts; it is empty
    # when the interval that started earlier ends before that point.
    if start_2 < start_1:
        later_start, earlier_interval_end = start_1, end_2
    else:
        later_start, earlier_interval_end = start_2, end_1
    if earlier_interval_end < later_start:
        return 0
    return min(end_1, end_2) - later_start
def compute_iou(box1, box2):
    """
    Compute IOU between box1 and box2.
    Arguments:
    box1 -- list of shape (5, ). Represents the first box
    box2 -- list of shape (5, ). Represents the second box
    Each box is [x, y, w, h, prob]; only the first four entries are used
    Returns:
    iou -- intersection over union score between two boxes. 1.0 is returned
           when the intersection spans the full width and height of either box,
           and 0.0 when the union has no area
    """
    x1, y1, w1, h1 = box1[0], box1[1], box1[2], box1[3]
    x2, y2, w2, h2 = box2[0], box2[1], box2[2], box2[3]
    intersect_w = overlap((x1, x1 + w1), (x2, x2 + w2))
    # The vertical extent of the second box is its height, not its width.
    intersect_h = overlap((y1, y1 + h1), (y2, y2 + h2))
    # NOTE(review): a box fully inside the other scores 1.0, which is higher
    # than the true IoU; presumably this is meant to make nested boxes get
    # suppressed by non-max suppression -- confirm before changing.
    if (intersect_w == w1 and intersect_h == h1) or (intersect_w == w2 and intersect_h == h2):
        return 1.
    intersect_area = intersect_w * intersect_h
    union_area = w1 * h1 + w2 * h2 - intersect_area
    if union_area <= 0:
        # Degenerate (zero-area) boxes: avoid dividing by zero.
        return 0.
    return intersect_area / union_area
def nonmax_suppression(boxes, iou_thresh = IOU_THRESH):
    """
    Removes overlapping bboxes, keeping the most probable box of each group.
    Arguments:
    boxes -- list of shape (None, 5); the list and its boxes are not modified
    iou_thresh -- maximal value of iou when boxes are considered different
    Each box is [x, y, w, h, prob]
    Returns:
    kept -- list of shape (None, 5) with the surviving boxes, ordered by
            decreasing prob. Boxes whose prob is not positive are dropped.
    """
    ranked = sorted(boxes, key=lambda box: box[4], reverse=True)
    # Track suppression separately instead of zeroing the probability in
    # place, so the caller's box objects keep their original values.
    suppressed = [False] * len(ranked)
    kept = []
    for i, current_box in enumerate(ranked):
        if suppressed[i] or not (current_box[4] > 0):
            continue
        kept.append(current_box)
        for j in range(i + 1, len(ranked)):
            if suppressed[j]:
                continue
            if compute_iou(current_box, ranked[j]) > iou_thresh:
                suppressed[j] = True
    return kept
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Draw a labelled heatmap with a colorbar.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    row_labels
        A list or array of length M with the labels for the rows.
    col_labels
        A list or array of length N with the labels for the columns.
    ax
        The `matplotlib.axes.Axes` to draw on. When omitted, the current
        axes (`plt.gca()`) is used. Optional.
    cbar_kw
        A dictionary of keyword arguments passed to the figure's `colorbar`
        call. Optional.
    cbarlabel
        The label for the colorbar. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    The image returned by `imshow` and the colorbar attached to it.
    """
    axes = plt.gca() if ax is None else ax
    colorbar_options = {} if cbar_kw is None else cbar_kw

    # Heatmap image and its colorbar.
    image = axes.imshow(data, **kwargs)
    colorbar = axes.figure.colorbar(image, ax=axes, **colorbar_options)
    colorbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # One major tick per column/row, labelled with the supplied entries.
    axes.set_xticks(np.arange(n_cols), labels=col_labels)
    axes.set_yticks(np.arange(n_rows), labels=row_labels)

    # Column labels go on top of the plot and are rotated around their anchor.
    axes.tick_params(top=True, bottom=False,
                     labeltop=True, labelbottom=False)
    plt.setp(axes.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Hide the spines and use minor ticks on the cell borders to draw a
    # white grid between the cells.
    axes.spines[:].set_visible(False)
    axes.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    axes.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    axes.grid(which="minor", color="w", linestyle='-', linewidth=3)
    axes.tick_params(which="minor", bottom=False, left=False)

    return image, colorbar
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    A function to annotate a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate, as a 2D numpy array or a nested list (lists
        are converted to an array). If None, the image's data is used.
        Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        A pair of colors. The first is used for values below a threshold,
        the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default) uses half of the normalized data
        maximum as separation. Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels.

    Returns
    -------
    The list of created text objects, in row-major order.
    """
    if isinstance(data, list):
        # A list has no `.max()` / `.shape` and cannot be indexed with [i, j].
        data = np.array(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # Index 0 (first color) at or below the threshold, 1 above it.
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)
    return texts