"""Matplotlib helpers that render detections and attention maps to images."""

import itertools

from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from constants import COLORS
from utils import fig2img


def _figure_to_image(fig):
    """Convert ``fig`` with ``fig2img`` and always close it afterwards.

    pyplot keeps every figure created via ``plt.subplots`` registered until it
    is closed explicitly, so without ``plt.close`` repeated calls would keep
    accumulating open figures.
    NOTE(review): assumes ``fig2img`` returns a standalone image (e.g. PIL)
    that does not need the figure to stay open — confirm in utils.
    """
    try:
        return fig2img(fig)
    finally:
        plt.close(fig)


def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
):
    """Draw the detections scoring above ``threshold`` on top of ``pil_img``.

    Args:
        pil_img: Background image passed to ``ax.imshow``.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries that support boolean-mask indexing and ``.tolist()``
            (presumably torch tensors from a detection model — confirm).
            Each kept box is unpacked as ``(xmin, ymin, xmax, ymax)``.
        threshold: Only detections whose score is strictly greater than this
            value are drawn.
        id2label: Optional mapping from label id to the text shown on the box.
        display_mask: When true and ``mask`` is given, overlay the mask.
        mask: Array-like mask; pixels > 0 are shown white and all others
            black, blended over the image at 50% opacity.

    Returns:
        The rendered figure, as produced by ``fig2img``.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(pil_img)

    if display_mask and mask is not None:
        # np.asarray lets PIL images / CPU tensors be compared with `> 0`.
        mask_array = np.asarray(mask)
        binary_mask = np.where(mask_array > 0, 255, 0).astype(np.uint8)
        # Pin the colour range: with autoscaling, an all-foreground mask has
        # vmin == vmax and would be drawn black instead of white.
        ax.imshow(binary_mask, alpha=0.5, cmap="gray", vmin=0, vmax=255)

    # Cycle the palette so every kept detection gets a colour; a finite
    # repeated list would make zip() silently drop the extra detections.
    for score, (xmin, ymin, xmax, ymax), label, color in zip(
        scores, boxes, labels, itertools.cycle(COLORS)
    ):
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=2,
            )
        )
        ax.text(
            xmin,
            ymin,
            f"{label}: {score:0.2f}",
            fontsize=10,
            bbox=dict(facecolor="yellow", alpha=0.5),
        )
    ax.axis("off")

    return _figure_to_image(fig)


def visualize_attention_map(pil_img, attention_map):
    """Overlay every attention head of the last layer on ``pil_img``.

    Args:
        pil_img: Background image; ``pil_img.size`` is reversed to get the
            ``(height, width)`` target size (PIL ``(width, height)``
            convention assumed).
        attention_map: Sequence of per-layer attention tensors. Only the last
            entry is used, indexed as ``[batch, head, rows, cols]``; only batch
            element 0 is drawn. What ``rows``/``cols`` mean (spatial grid vs.
            query/key tokens) depends on the caller — TODO confirm.

    Returns:
        One subplot per head, rendered via ``fig2img``.
    """
    # Last layer, detached from autograd and on CPU so it can be plotted.
    last_layer = attention_map[-1].detach().cpu()
    n_heads = last_layer.shape[1]

    # Resize every head of batch element 0 to the image's (height, width) in a
    # single call so each overlay spans the whole image instead of sitting in
    # its top-left corner. `.float()` because bicubic needs a float dtype.
    per_head_maps = F.interpolate(
        last_layer[:1].float(),
        size=pil_img.size[::-1],
        mode="bicubic",
    )[0].numpy()

    # squeeze=False keeps `axes` a 2-D array even for a single head; otherwise
    # plt.subplots returns a bare Axes that has no `.flat`.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )

    for i, (ax, head_map) in enumerate(zip(axes.flat, per_head_maps)):
        ax.imshow(pil_img)
        ax.imshow(head_map, alpha=0.7, cmap="viridis")
        ax.set_title(f"Head {i + 1}")
        ax.axis("off")

    fig.tight_layout()
    return _figure_to_image(fig)