yangapku commited on
Commit
6e72378
1 Parent(s): 6fa2bfd

remove fix-sized causal mask

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +3 -76
modeling_qwen.py CHANGED
@@ -395,62 +395,6 @@ class QWenAttention(nn.Module):
395
 
396
  return attn_output, attn_weights
397
 
398
- def _upcast_and_reordered_attn(
399
- self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
400
- ):
401
- bsz, num_heads, q_seq_len, dk = query.size()
402
- _, _, k_seq_len, _ = key.size()
403
-
404
- attn_weights = torch.empty(
405
- bsz * num_heads,
406
- q_seq_len,
407
- k_seq_len,
408
- dtype=torch.float32,
409
- device=query.device,
410
- )
411
-
412
- scale_factor = 1.0
413
- if self.scale_attn_weights:
414
- scale_factor /= float(value.size(-1)) ** 0.5
415
-
416
- with autocast(enabled=False):
417
- q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
418
- -1, dk, k_seq_len
419
- )
420
- attn_weights = torch.baddbmm(
421
- attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
422
- )
423
- attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
424
-
425
- query_length, key_length = query.size(-2), key.size(-2)
426
- causal_mask = registered_causal_mask[
427
- :, :, key_length - query_length : key_length, :key_length
428
- ]
429
- mask_value = torch.finfo(attn_weights.dtype).min
430
- mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
431
- attn_weights.device
432
- )
433
- attn_weights = torch.where(causal_mask, attn_weights, mask_value)
434
-
435
- if attention_mask is not None:
436
- attn_weights = attn_weights + attention_mask
437
-
438
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
439
-
440
- if attn_weights.dtype != torch.float32:
441
- raise RuntimeError(
442
- "Error with upcasting, attn_weights does not have dtype torch.float32"
443
- )
444
- attn_weights = attn_weights.type(value.dtype)
445
- attn_weights = self.attn_dropout(attn_weights)
446
-
447
- if head_mask is not None:
448
- attn_weights = attn_weights * head_mask
449
-
450
- attn_output = torch.matmul(attn_weights, value)
451
-
452
- return attn_output, attn_weights
453
-
454
  def _split_heads(self, tensor, num_heads, attn_head_size):
455
  new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
456
  tensor = tensor.view(new_shape)
@@ -465,7 +409,6 @@ class QWenAttention(nn.Module):
465
  self,
466
  hidden_states: Optional[Tuple[torch.FloatTensor]],
467
  rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
468
- registered_causal_mask: Optional[torch.Tensor] = None,
469
  layer_past: Optional[Tuple[torch.Tensor]] = None,
470
  attention_mask: Optional[torch.FloatTensor] = None,
471
  head_mask: Optional[torch.FloatTensor] = None,
@@ -558,6 +501,9 @@ class QWenAttention(nn.Module):
558
  q, k, v = query, key, value
559
  attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
560
  else:
 
 
 
561
  query = query.permute(0, 2, 1, 3)
562
  if not self.use_cache_quantization:
563
  key = key.permute(0, 2, 1, 3)
@@ -650,7 +596,6 @@ class QWenBlock(nn.Module):
650
  self,
651
  hidden_states: Optional[Tuple[torch.FloatTensor]],
652
  rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
653
- registered_causal_mask: Optional[torch.Tensor] = None,
654
  layer_past: Optional[Tuple[torch.Tensor]] = None,
655
  attention_mask: Optional[torch.FloatTensor] = None,
656
  head_mask: Optional[torch.FloatTensor] = None,
@@ -664,7 +609,6 @@ class QWenBlock(nn.Module):
664
  attn_outputs = self.attn(
665
  layernorm_output,
666
  rotary_pos_emb_list,
667
- registered_causal_mask=registered_causal_mask,
668
  layer_past=layer_past,
669
  attention_mask=attention_mask,
670
  head_mask=head_mask,
@@ -764,21 +708,6 @@ class QWenModel(QWenPreTrainedModel):
764
 
765
  self.use_flash_attn = config.use_flash_attn
766
  self.is_fp32 = not (config.bf16 or config.fp16)
767
- if (
768
- self.use_flash_attn
769
- and flash_attn_unpadded_func is not None
770
- and not self.is_fp32
771
- ):
772
- self.registered_causal_mask = None
773
- else:
774
- max_positions = config.max_position_embeddings
775
- self.register_buffer(
776
- "registered_causal_mask",
777
- torch.tril(
778
- torch.ones((max_positions, max_positions), dtype=torch.bool)
779
- ).view(1, 1, max_positions, max_positions),
780
- persistent=False,
781
- )
782
 
783
  self.h = nn.ModuleList(
784
  [
@@ -950,7 +879,6 @@ class QWenModel(QWenPreTrainedModel):
950
  create_custom_forward(block),
951
  hidden_states,
952
  rotary_pos_emb_list,
953
- self.registered_causal_mask,
954
  None,
955
  attention_mask,
956
  head_mask[i],
@@ -962,7 +890,6 @@ class QWenModel(QWenPreTrainedModel):
962
  hidden_states,
963
  layer_past=layer_past,
964
  rotary_pos_emb_list=rotary_pos_emb_list,
965
- registered_causal_mask=self.registered_causal_mask,
966
  attention_mask=attention_mask,
967
  head_mask=head_mask[i],
968
  encoder_hidden_states=encoder_hidden_states,
 
395
 
396
  return attn_output, attn_weights
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  def _split_heads(self, tensor, num_heads, attn_head_size):
399
  new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
400
  tensor = tensor.view(new_shape)
 
409
  self,
410
  hidden_states: Optional[Tuple[torch.FloatTensor]],
411
  rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
 
412
  layer_past: Optional[Tuple[torch.Tensor]] = None,
413
  attention_mask: Optional[torch.FloatTensor] = None,
414
  head_mask: Optional[torch.FloatTensor] = None,
 
501
  q, k, v = query, key, value
502
  attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
503
  else:
504
+ registered_causal_mask = torch.tril(
505
+ torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
506
+ ).view(1, 1, key.size(1), key.size(1))
507
  query = query.permute(0, 2, 1, 3)
508
  if not self.use_cache_quantization:
509
  key = key.permute(0, 2, 1, 3)
 
596
  self,
597
  hidden_states: Optional[Tuple[torch.FloatTensor]],
598
  rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
 
599
  layer_past: Optional[Tuple[torch.Tensor]] = None,
600
  attention_mask: Optional[torch.FloatTensor] = None,
601
  head_mask: Optional[torch.FloatTensor] = None,
 
609
  attn_outputs = self.attn(
610
  layernorm_output,
611
  rotary_pos_emb_list,
 
612
  layer_past=layer_past,
613
  attention_mask=attention_mask,
614
  head_mask=head_mask,
 
708
 
709
  self.use_flash_attn = config.use_flash_attn
710
  self.is_fp32 = not (config.bf16 or config.fp16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
 
712
  self.h = nn.ModuleList(
713
  [
 
879
  create_custom_forward(block),
880
  hidden_states,
881
  rotary_pos_emb_list,
 
882
  None,
883
  attention_mask,
884
  head_mask[i],
 
890
  hidden_states,
891
  layer_past=layer_past,
892
  rotary_pos_emb_list=rotary_pos_emb_list,
 
893
  attention_mask=attention_mask,
894
  head_mask=head_mask[i],
895
  encoder_hidden_states=encoder_hidden_states,