xingzhang committed
Commit
4572596
1 Parent(s): dfa6051

update modeling_qwen.py

Files changed (1)
  1. modeling_qwen.py +1 -1
modeling_qwen.py CHANGED
@@ -520,7 +520,7 @@ class QWenAttention(nn.Module):
 
             if not self.use_cache_quantization and SUPPORT_TORCH2:
                 if attention_mask is not None:
-                    attention_mask = attention_mask.expand(-1, -1, key_size, -1)
+                    attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
                     if causal_mask is not None:
                         attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
                 else:
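
Context for the change, not part of the commit: torch 2's scaled_dot_product_attention expects attn_mask to broadcast against attention weights of shape (batch, heads, q_len, k_len), so the mask's third dimension has to follow the query length rather than the key length. The sketch below is illustrative only; the tensor sizes and variable names are hypothetical, chosen so that q_len != k_len (as in incremental decoding with a KV cache).

# Minimal sketch, assuming a boolean padding mask of shape (batch, 1, 1, k_len)
# where True means "attend". Not taken from modeling_qwen.py.
import torch
import torch.nn.functional as F

batch, heads, q_len, k_len, head_dim = 1, 2, 3, 7, 8   # q_len != k_len
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, k_len, head_dim)
value = torch.randn(batch, heads, k_len, head_dim)

attention_mask = torch.ones(batch, 1, 1, k_len, dtype=torch.bool)

# Expanding dim 2 to the key length (pre-fix) would give (batch, 1, k_len, k_len),
# which cannot broadcast to (batch, heads, q_len, k_len) once q_len != k_len.
# Expanding to the query length (post-fix) gives a shape SDPA accepts:
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)   # (1, 1, 3, 7)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)   # torch.Size([1, 2, 3, 8])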