johnhew ivanzhouyq committed on
Commit
a9697cc
·
1 Parent(s): 988edaf

Set more precise shape to the attention weights and outputs (#1)

Browse files

- Set more precise shape to the attention weights and outputs (b2c5167f662e0f000f52ee4dd00e67af376067ca)


Co-authored-by: Ivan Zhou <[email protected]>

Files changed (1) hide show
  1. modeling_backpack_gpt2.py +3 -2
modeling_backpack_gpt2.py CHANGED
@@ -101,13 +101,14 @@ class BackpackWeightNetwork(nn.Module):
101
  super().__init__()
102
  self.n_embd = embed_dim
103
  self.num_senses = num_senses
104
- self.c_attn = nn.Linear(embed_dim, 2*embed_dim)
 
105
  self.softmax_scale = None
106
 
107
  def forward(self, encoded):
108
  b, s, d = encoded.shape
109
  encoded = self.c_attn(encoded) # (b, s, 2*d)
110
- encoded = encoded.reshape(b, s, 2, self.num_senses, d // self.num_senses) #(b, s, 2, nv, d//nv)
111
  batch_size, seqlen = encoded.shape[0], encoded.shape[1]
112
 
113
  # compute scores & mask
 
101
  super().__init__()
102
  self.n_embd = embed_dim
103
  self.num_senses = num_senses
104
+ self.embed_per_sense = embed_dim // num_senses
105
+ self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
106
  self.softmax_scale = None
107
 
108
  def forward(self, encoded):
109
  b, s, d = encoded.shape
110
  encoded = self.c_attn(encoded) # (b, s, 2*d)
111
+ encoded = encoded.reshape(b, s, 2, self.num_senses, self.embed_per_sense) #(b, s, 2, nv, d//nv)
112
  batch_size, seqlen = encoded.shape[0], encoded.shape[1]
113
 
114
  # compute scores & mask