polejowska commited on
Commit
ef25264
1 Parent(s): e203078

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +43 -14
visualization.py CHANGED
@@ -45,20 +45,49 @@ def visualize_prediction(
45
 
46
 
47
  def visualize_attention_map(pil_img, attention_map):
 
48
  attention_map = attention_map[-1].detach().cpu()
 
 
 
 
 
49
  avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
50
- avg_attention_weight_resized = (
51
- F.interpolate(
52
- avg_attention_weight.unsqueeze(0).unsqueeze(0),
53
- size=pil_img.size[::-1],
54
- mode="bicubic",
55
- )
56
- .squeeze()
57
- .numpy()
58
- )
59
-
60
- plt.imshow(pil_img)
61
- plt.imshow(avg_attention_weight_resized, alpha=0.7, cmap="viridis")
62
- plt.axis("off")
63
- fig = plt.gcf()
 
 
 
 
 
 
64
  return fig2img(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  def visualize_attention_map(pil_img, attention_map):
48
+ # Get the attention map for the last layer
49
  attention_map = attention_map[-1].detach().cpu()
50
+
51
+ # Get the number of heads
52
+ n_heads = attention_map.shape[1]
53
+
54
+ # Calculate the average attention weight for each head
55
  avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
56
+
57
+ # Resize the attention map
58
+ resized_attention_weight = F.interpolate(
59
+ avg_attention_weight.unsqueeze(0).unsqueeze(0),
60
+ size=pil_img.size[::-1],
61
+ mode="bicubic",
62
+ ).squeeze().numpy()
63
+
64
+ # Create a grid of subplots
65
+ fig, axes = plt.subplots(nrows=1, ncols=n_heads, figsize=(n_heads*4, 4))
66
+
67
+ # Loop through the subplots and plot the attention for each head
68
+ for i, ax in enumerate(axes.flat):
69
+ ax.imshow(pil_img)
70
+ ax.imshow(attention_map[0,i,:,:].squeeze(), alpha=0.7, cmap="viridis")
71
+ ax.set_title(f"Head {i+1}")
72
+ ax.axis("off")
73
+
74
+ plt.tight_layout()
75
+
76
  return fig2img(fig)
77
+ # attention_map = attention_map[-1].detach().cpu()
78
+ # avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
79
+ # avg_attention_weight_resized = (
80
+ # F.interpolate(
81
+ # avg_attention_weight.unsqueeze(0).unsqueeze(0),
82
+ # size=pil_img.size[::-1],
83
+ # mode="bicubic",
84
+ # )
85
+ # .squeeze()
86
+ # .numpy()
87
+ # )
88
+
89
+ # plt.imshow(pil_img)
90
+ # plt.imshow(avg_attention_weight_resized, alpha=0.7, cmap="viridis")
91
+ # plt.axis("off")
92
+ # fig = plt.gcf()
93
+ # return fig2img(fig)