update modeling_qwen.py
Browse files- assets/wechat.png +0 -0
- modeling_qwen.py +2 -1
assets/wechat.png
CHANGED
modeling_qwen.py
CHANGED
@@ -193,9 +193,10 @@ class FlashSelfAttention(torch.nn.Module):
|
|
193 |
if attention_mask is not None:
|
194 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
195 |
v = v[indices_k]
|
196 |
-
if
|
197 |
q = q[indices_k]
|
198 |
cu_seqlens_q = cu_seqlens_k
|
|
|
199 |
else:
|
200 |
cu_seqlens_k = torch.arange(
|
201 |
0,
|
|
|
193 |
if attention_mask is not None:
|
194 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
195 |
v = v[indices_k]
|
196 |
+
if self.training or q.size(0) == k.size(0):
|
197 |
q = q[indices_k]
|
198 |
cu_seqlens_q = cu_seqlens_k
|
199 |
+
seqlen_q = seqlen_k
|
200 |
else:
|
201 |
cu_seqlens_k = torch.arange(
|
202 |
0,
|