File size: 3,402 Bytes
3370ff8
ac09809
032542c
3370ff8
 
 
 
 
 
 
 
edbbf31
3370ff8
f4de4c9
 
 
 
3370ff8
 
 
 
 
 
 
 
 
2a32e11
1d1bc10
 
 
2a32e11
1d1bc10
 
 
 
 
2a32e11
 
a2c3378
1d1bc10
3370ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f928485
 
3370ff8
 
 
 
 
 
 
ef25264
3370ff8
ef25264
 
 
 
 
3370ff8
ef25264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3370ff8
ef25264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import itertools

from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F

from constants import COLORS
from utils import fig2img


def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
):
    """Draw the detections scoring above ``threshold`` on top of ``pil_img``.

    Args:
        pil_img: Image handed to ``ax.imshow`` as the background.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries. Each entry must support boolean-mask indexing and
            ``.tolist()`` (presumably torch tensors -- confirm against the
            caller). Every box is unpacked as ``(xmin, ymin, xmax, ymax)``.
        threshold: Only detections with a score strictly greater than this
            value are drawn.
        id2label: Optional mapping from label id to label name. When given,
            every kept label id is looked up in it.
        display_mask: When true and ``mask`` is not ``None``, overlay the
            mask on the image.
        mask: Array-like mask; every element greater than zero is shown as
            foreground in the overlay.

    Returns:
        The result of ``fig2img`` applied to the rendered figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(pil_img)
        if display_mask and mask is not None:
            mask_arr = np.asarray(mask)

            # Binarize: anything above zero becomes 255, the rest stays 0.
            binary_mask = np.zeros_like(mask_arr)
            binary_mask[mask_arr > 0] = 255

            # Semi-transparent overlay so the image stays visible underneath.
            ax.imshow(Image.fromarray(binary_mask), alpha=0.5, cmap="viridis")

        # Cycle through the palette so that no detection is dropped when
        # there are more boxes than colors.
        palette = itertools.cycle(COLORS)
        # NOTE: labels are resolved above but only the score is rendered.
        for score, (xmin, ymin, xmax, ymax), _label, color in zip(
            scores, boxes, labels, palette
        ):
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=2,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{score:0.2f}",
                fontsize=8,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps a reference to every open figure; release it so that
        # repeated calls do not accumulate figures in memory. Assumes
        # fig2img has finished rendering by the time it returns.
        plt.close(fig)


def visualize_attention_map(pil_img, attention_map):
    """Overlay every attention head of the last layer on ``pil_img``.

    One subplot is created per head; each shows the image with that head's
    attention map, upsampled to the image size, drawn on top of it.

    Args:
        pil_img: Image whose ``size`` attribute is ``(width, height)``.
        attention_map: Sequence of attention tensors; only the last entry is
            used. It is indexed as ``[0, head, :, :]``, so it is assumed to
            be laid out as (batch, heads, rows, cols) -- TODO confirm. Only
            the first batch element is drawn.

    Returns:
        The result of ``fig2img`` applied to the rendered figure.
    """
    # Cast to float32 so that bicubic interpolation works on CPU regardless
    # of the precision the model ran in.
    last_layer = attention_map[-1].detach().cpu().float()
    n_heads = last_layer.shape[1]

    # Upsample all heads of the first batch element to (height, width) so
    # that each overlay lines up with the image instead of being drawn at
    # the attention map's own, much smaller, resolution.
    per_head = F.interpolate(
        last_layer[:1],
        size=pil_img.size[::-1],
        mode="bicubic",
    )[0].numpy()

    # squeeze=False keeps ``axes`` a 2-D array even when there is one head.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )
    try:
        for i, (ax, head_map) in enumerate(zip(axes.flat, per_head)):
            ax.imshow(pil_img)
            ax.imshow(head_map, alpha=0.7, cmap="viridis")
            ax.set_title(f"Head {i+1}")
            ax.axis("off")

        fig.tight_layout()
        return fig2img(fig)
    finally:
        # Release the figure so repeated calls do not accumulate figures.
        # Assumes fig2img has finished rendering by the time it returns.
        plt.close(fig)