xumingyu16
commited on
fix bug when generate
Browse files- modeling_baichuan.py +2 -2
modeling_baichuan.py
CHANGED
@@ -221,10 +221,10 @@ class Attention(nn.Module):
|
|
221 |
# head_dim=self.head_dim, n_head=self.num_heads)
|
222 |
# q, k, v = proj
|
223 |
if past_key_value is None:
|
|
|
|
|
224 |
k = custom_convolution(k, self.K)
|
225 |
v = custom_convolution(v, self.V)
|
226 |
-
self.last_k = k[:,-1:]
|
227 |
-
self.last_v = v[:,-1:]
|
228 |
else:
|
229 |
self.last_k,k = k, self.K[:,:1]*self.last_k + self.K[:,1:]*k
|
230 |
self.last_v,v = v, self.V[:,:1]*self.last_v + self.V[:,1:]*v
|
|
|
221 |
# head_dim=self.head_dim, n_head=self.num_heads)
|
222 |
# q, k, v = proj
|
223 |
if past_key_value is None:
|
224 |
+
self.last_k = k[:,-1:]
|
225 |
+
self.last_v = v[:,-1:]
|
226 |
k = custom_convolution(k, self.K)
|
227 |
v = custom_convolution(v, self.V)
|
|
|
|
|
228 |
else:
|
229 |
self.last_k,k = k, self.K[:,:1]*self.last_k + self.K[:,1:]*k
|
230 |
self.last_v,v = v, self.V[:,:1]*self.last_v + self.V[:,1:]*v
|