# Spaces: Runtime error
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Create a heatmap from a 2D array and two lists of labels.

    Parameters
    ----------
    data
        A 2D array-like of shape (M, N). Nested lists are accepted and
        converted to a numpy array.
    row_labels
        A list or array of length M with the labels for the rows.
    col_labels
        A list or array of length N with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
        not provided, use current axes or create a new one.  Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional.
    cbarlabel
        The label for the colorbar.  Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    im, cbar
        The image returned by `imshow` and the colorbar attached to it.
    """
    # The tick setup below reads ``data.shape``; asanyarray lets nested lists
    # through while keeping ndarray subclasses (e.g. masked arrays) as they are.
    data = np.asanyarray(data)

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    ax.spines[:].set_visible(False)

    # Minor ticks sit on the cell boundaries (half a cell off the centers),
    # so the white minor grid draws the separators between cells.
    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Annotate a heatmap with one text label per cell.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate, as a list or numpy array.  For any other
        value (including None) the image's own data is used.  Optional.
    valfmt
        The format of the annotations inside the heatmap.  This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`.  Optional.
    textcolors
        A pair of colors.  The first is used for values below a threshold,
        the second for those above.  Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied.  If None (the default), half of the normalized maximum of
        `data` is used as separation.  Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels.

    Returns
    -------
    texts
        The list of created text objects, in row-major order.
    """
    if isinstance(data, (list, np.ndarray)):
        # A plain list has no ``.max()``/``.shape`` and cannot be indexed with
        # ``[i, j]``; asanyarray converts it and leaves masked arrays untouched.
        data = np.asanyarray(data)
    else:
        data = im.get_array()

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # The boolean comparison indexes textcolors: 0 below, 1 above.
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts
def visualize_bbox(image: Image.Image, prediction):
    """
    Draw numbered red bounding boxes on a copy of an image.

    Parameters
    ----------
    image
        The PIL image to annotate.  It is copied, not modified.
    prediction
        An iterable of boxes, each holding four values ``x1, y1, x2, y2``.
        A box exposing a ``.cpu()`` method is moved through it first
        (presumably a torch tensor; verify against the caller); any other
        four-element sequence of numbers is used directly.

    Returns
    -------
    img
        The annotated copy; boxes are labeled 1, 2, ... in iteration order.
    """
    img = image.copy()
    # One drawing context serves every box.
    draw = ImageDraw.Draw(img)

    for number, box in enumerate(prediction, start=1):
        label = str(number)

        coords = box.cpu() if hasattr(box, "cpu") else box
        # Plain floats keep the comparison and arithmetic below in Python
        # numbers before they are handed to Pillow.
        x1, y1, x2, y2 = (float(value) for value in coords)

        if hasattr(draw, "textbbox"):
            # ``textsize`` was removed in Pillow 10.  Measuring from the
            # origin gives the extent the text covers from its anchor point.
            _, _, text_w, text_h = draw.textbbox((0, 0), label)
        else:
            # Older Pillow releases without ``textbbox``.
            text_w, text_h = draw.textsize(label)

        # Put the label above the box unless that would leave the image.
        label_y = y1 if y1 <= text_h else y1 - text_h

        draw.rectangle((x1, y1, x2, y2), outline='red')
        draw.rectangle((x1, label_y, x1+text_w, label_y+text_h), outline='red', fill='red')
        draw.text((x1, label_y), label, fill='white')

    return img