Spaces:
Runtime error
Runtime error
File size: 3,145 Bytes
3370ff8 edbbf31 3370ff8 edbbf31 3370ff8 f4de4c9 3370ff8 edbbf31 3370ff8 ef25264 3370ff8 ef25264 3370ff8 ef25264 3370ff8 ef25264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import itertools

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt

from constants import COLORS
from utils import fig2img
def visualize_mask(mask, img, alpha=0.5):
    """Overlay ``mask`` on ``img`` and return the rendered figure via ``fig2img``.

    Args:
        mask: Tensor holding the mask. It is squeezed before plotting, so all
            singleton dimensions are dropped (presumably (1, H, W) or
            (1, 1, H, W) -- TODO confirm against callers).
        img: Tensor holding the image. After squeezing it is permuted with
            ``(1, 2, 0)`` for ``imshow``, so it is expected to be channels-first
            (C, H, W) once any singleton batch dimension is removed.
        alpha: Opacity of the mask layer drawn on top of the image.

    Returns:
        Whatever ``fig2img`` produces for the rendered figure.
    """
    # detach() so tensors that still require grad can be converted to numpy.
    mask_np = mask.detach().cpu().squeeze().numpy()
    img_np = img.detach().cpu().squeeze().permute(1, 2, 0).numpy()

    # Use a dedicated figure instead of the global plt.gcf(): drawing on the
    # current figure stacked layers across calls and never released it.
    fig, ax = plt.subplots()
    try:
        ax.imshow(img_np)
        ax.imshow(mask_np, alpha=alpha)
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Free the figure once it has been rasterized; pyplot otherwise keeps
        # every figure alive for the lifetime of the process.
        plt.close(fig)
def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
):
    """Draw detections scoring above ``threshold`` over ``pil_img``.

    Args:
        pil_img: Image to draw on (anything ``Axes.imshow`` accepts).
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries. They are filtered with a boolean mask and converted with
            ``.tolist()``, so they are presumably torch tensors; each box is
            unpacked as ``(xmin, ymin, xmax, ymax)``.
        threshold: Detections with ``score <= threshold`` are discarded.
        id2label: Optional mapping indexed by label id to get display names.
        display_mask: When true and ``mask`` is provided, overlay ``mask`` at
            50% opacity beneath the boxes.
        mask: Optional mask to overlay. Torch tensors are moved to CPU and
            squeezed first; other array-likes are passed to ``imshow`` as is.

    Returns:
        Whatever ``fig2img`` produces for the rendered figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(pil_img)
        # imshow(None) raised when the mask toggle was on without a mask.
        if display_mask and mask is not None:
            if torch.is_tensor(mask):
                # matplotlib cannot draw CUDA / grad-tracking tensors directly.
                mask = mask.detach().cpu().squeeze().numpy()
            ax.imshow(mask, alpha=0.5)

        # cycle() instead of `COLORS * 100`: zip() stops at its shortest input,
        # so a finite color list silently dropped boxes past its length.
        for score, (xmin, ymin, xmax, ymax), label, color in zip(
            scores, boxes, labels, itertools.cycle(COLORS)
        ):
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=2,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{label}: {score:0.2f}",
                fontsize=10,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Release the figure after rasterizing to avoid accumulating figures.
        plt.close(fig)
def visualize_attention_map(pil_img, attention_map):
    """Plot every head of the last attention layer over ``pil_img``.

    Args:
        pil_img: PIL image; ``pil_img.size`` is read as ``(width, height)``.
        attention_map: Sequence of per-layer attention tensors. Only the last
            entry is used and it is indexed as ``[batch, head, :, :]``, so it
            is assumed to be (batch, num_heads, q, k) -- presumably a
            transformer's ``attentions`` output; TODO confirm. Only batch
            element 0 is plotted.

    Returns:
        Whatever ``fig2img`` produces for a one-row grid with one panel per
        head, each showing that head's map resized to the image and blended
        over it.
    """
    last_layer = attention_map[-1].detach().cpu()
    n_heads = last_layer.shape[1]
    width, height = pil_img.size

    # Resize every head of batch element 0 to the image resolution so the
    # overlay is aligned with the picture. Previously the raw (q, k) maps were
    # drawn in their own pixel coordinates on top of the full-size image.
    # float() because bicubic interpolation is not available for all dtypes.
    per_head_maps = F.interpolate(
        last_layer[0:1].float(),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    )[0].numpy()

    # squeeze=False keeps `axes` a 2-D array even when n_heads == 1; the
    # default returns a bare Axes there and `.flat` would fail.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )
    try:
        for i, ax in enumerate(axes.flat):
            ax.imshow(pil_img)
            ax.imshow(per_head_maps[i], alpha=0.7, cmap="viridis")
            ax.set_title(f"Head {i+1}")
            ax.axis("off")
        plt.tight_layout()
        return fig2img(fig)
    finally:
        # Release the figure after rasterizing to avoid accumulating figures.
        plt.close(fig)
|