wuhp committed · verified
Commit 625ec9b · Parent(s): 1520350

Update myr1/modeling_deepseek.py

Files changed (1): myr1/modeling_deepseek.py (+76 -19)
myr1/modeling_deepseek.py CHANGED
@@ -2,9 +2,19 @@
 modeling_deepseek.py
 
 An improved version of the DeepSeekV3 model code with added docstrings, in-line commentary,
-some mild refactoring, and suggestions for potential future enhancements. This version is
-intended for demonstration and testing. Actual performance gains may vary based on your
-environment and training data.
+some mild refactoring, and suggestions for potential future enhancements for **reasoning** and **efficiency**.
+This version incorporates architectural considerations for enhanced reasoning,
+efficiency improvements like GQA (configurable), and placeholders for more advanced features.
+Actual performance gains may vary based on your environment and training data.
+
+**Important Notes:**
+
+* **Configuration is Key:** Many reasoning and efficiency improvements are driven by configuration changes in `configuration_deepseek.py`. You will need to update your config file to fully utilize these features. See comments marked with `[CONFIG]` for configuration-related suggestions.
+* **Placeholders:** This code includes placeholders (comments and `TODO`s) for features like Sparse Attention, more advanced MoE gating, and Chain-of-Thought (CoT) prompting. Implementing these fully requires more code modifications and potentially changes to your training/inference pipelines.
+* **Data is Crucial for Reasoning:** Reasoning improvements heavily depend on training data. Consider fine-tuning or pre-training on reasoning-focused datasets and using techniques like Chain-of-Thought data augmentation.
+* **Grouped-Query Attention (GQA):** This version includes comments and configuration hints for GQA. Full GQA implementation would require modifying the attention logic in `DeepseekV3Attention.forward()` to handle grouped K/V heads. Currently, it's MQA-style.
+* **Sparse Attention:** Placeholder for integrating Sparse Attention (e.g., Longformer, BigBird). You would need to implement a `SparseDeepseekV3Attention` class and integrate it.
+
 """
 
 import math
@@ -45,7 +55,7 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 
 # Import your configuration
-from .configuration_deepseek import DeepseekV3Config
+from .configuration_deepseek import DeepseekV3Config  # [CONFIG] Make sure DeepseekV3Config has new parameters
 
 import torch.distributed as dist
 import numpy as np
@@ -330,7 +340,7 @@ class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
         self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
 
 
-# ==============================================================================
+# ==============================================================================
 # General Rotary helper functions
 # ==============================================================================
 
@@ -438,6 +448,8 @@ class MoEGate(nn.Module):
         logits = F.linear(hidden_states.float(), self.weight.float(), None)
         if self.scoring_func == "sigmoid":
             scores = logits.sigmoid()
+        elif self.scoring_func == "softmax":  # [CONFIG] Option for softmax gating
+            scores = logits.softmax(dim=-1)
         else:
             raise NotImplementedError(
                 f"Unsupported gating scoring function: {self.scoring_func}"
@@ -462,6 +474,9 @@ class MoEGate(nn.Module):
             tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
             _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
             topk_weight = scores_for_choice.gather(1, topk_idx)
+        elif self.topk_method == "topk_gating":  # [CONFIG] Option for simpler top-k gating
+            _, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+            topk_weight = torch.gather(scores, dim=-1, index=topk_idx)
         else:
             raise NotImplementedError(
                 f"Unsupported topk_method: {self.topk_method}"
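The `softmax` scoring option and the `topk_gating` branch added above reduce to a softmax over the expert logits followed by plain top-k selection. A minimal standalone sketch with hypothetical shapes (not the module's exact code):

```python
import torch

# Hypothetical sizes: 4 tokens routed over 8 experts, top_k = 2.
torch.manual_seed(0)
logits = torch.randn(4, 8)           # gate logits, shape [n_tokens, n_experts]
scores = logits.softmax(dim=-1)      # the "softmax" scoring_func

top_k = 2
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

# Equivalent to the gather used in the diff: picking the scores at the top-k indices.
assert torch.allclose(topk_weight, torch.gather(scores, dim=-1, index=topk_idx))
print(topk_idx.shape, topk_weight.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```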
@@ -656,6 +671,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class DeepseekV3Attention(nn.Module):
     """
     Standard multi-headed attention for Deepseek.
+
+    **Reasoning & Efficiency Improvements Considered:**
+    * **Grouped-Query Attention (GQA):** Configurable via `config.num_key_value_heads` and `config.num_attention_heads`. If `num_key_value_heads < num_attention_heads`, GQA is implicitly enabled. See comments in `forward()` for GQA implementation hints. [CONFIG]
+    * **Sparse Attention:** Placeholder for integration. To use sparse attention, you would need to create a `SparseDeepseekV3Attention` class (e.g., based on LongformerAttention or BigBirdAttention) and replace `DeepseekV3Attention` in `ATTENTION_CLASSES` based on a config flag. [CONFIG]
     """
     def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
         super().__init__()
@@ -665,6 +684,13 @@ class DeepseekV3Attention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads  # [CONFIG] For GQA
+
+        if self.num_heads % self.num_key_value_heads != 0:  # GQA check
+            raise ValueError(
+                "num_attention_heads must be divisible by num_key_value_heads (for GQA)"
+            )
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # For GQA
 
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
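The divisibility check and `num_key_value_groups` computation added above are the standard GQA bookkeeping. A small sketch with hypothetical config values showing how query heads map onto shared K/V heads:

```python
# Hypothetical values; the real ones would come from DeepseekV3Config.
num_attention_heads = 32
num_key_value_heads = 8   # GQA: fewer K/V heads than query heads

if num_attention_heads % num_key_value_heads != 0:
    raise ValueError("num_attention_heads must be divisible by num_key_value_heads (for GQA)")

num_key_value_groups = num_attention_heads // num_key_value_heads
print(num_key_value_groups)  # 4 -> every K/V head is shared by 4 query heads
```

With `num_key_value_heads = 1` this degenerates to MQA, and with `num_key_value_heads = num_attention_heads` it is ordinary multi-head attention.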
@@ -691,16 +717,16 @@ class DeepseekV3Attention(nn.Module):
                 config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
             )
 
-        # K,V-proj (MQA style)
+        # K,V-proj (MQA/GQA style)
         self.kv_a_proj_with_mqa = nn.Linear(
             self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
+            config.kv_lora_rank + self.qk_rope_head_dim,
             bias=config.attention_bias,
         )
         self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
         self.kv_b_proj = nn.Linear(
             config.kv_lora_rank,
-            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+            self.num_key_value_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),  # [GQA] num_key_value_heads here
             bias=False,
         )
 
@@ -731,8 +757,8 @@ class DeepseekV3Attention(nn.Module):
                 base=self.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
+            scaling_type = self.config.rope_scaling["type"]  # [CONFIG] Rope scaling type
+            scaling_factor = self.config.rope_scaling["factor"]  # [CONFIG] Rope scaling factor
 
             if scaling_type == "linear":
                 self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
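The lookup above expects `config.rope_scaling` to be a dict with at least `type` and `factor` keys. A hypothetical example of such an entry and the dispatch it drives (illustrative only, not the file's exact wiring):

```python
# Hypothetical config fragment; key names follow the diff's usage
# (config.rope_scaling["type"] / config.rope_scaling["factor"]).
rope_scaling = {"type": "yarn", "factor": 4.0}

scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]

if scaling_type == "linear":
    chosen = "DeepseekV3LinearScalingRotaryEmbedding"
elif scaling_type == "yarn":
    chosen = "DeepseekV3YarnRotaryEmbedding"
else:
    raise NotImplementedError(f"Unsupported rope scaling type: {scaling_type}")
print(chosen, scaling_factor)  # DeepseekV3YarnRotaryEmbedding 4.0
```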
@@ -782,6 +808,12 @@ class DeepseekV3Attention(nn.Module):
     ):
         """
         Standard forward pass for multi-headed self-attention.
+
+        **Grouped-Query Attention (GQA) Implementation Notes:**
+        If `num_key_value_heads < num_attention_heads` (GQA is configured):
+        1. `kv` projection will produce `num_key_value_heads` * (head_dim * 2) channels.
+        2. We need to *repeat* the `key_states` and `value_states` `num_key_value_groups` times along the head dimension to match the `num_attention_heads` for query.
+        3. `repeat_kv` utility function is used for this repetition.
         """
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -798,7 +830,7 @@ class DeepseekV3Attention(nn.Module):
         q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
         q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        # MQA: K,V from single projection
+        # MQA/GQA: K,V from single projection
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, k_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -806,7 +838,7 @@ class DeepseekV3Attention(nn.Module):
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
         kv = (
             self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, q_len, self.num_key_value_heads, self.qk_nope_head_dim + self.v_head_dim)  # [GQA] num_key_value_heads here
             .transpose(1, 2)
         )
         k_nope, value_states = torch.split(
@@ -829,10 +861,17 @@ class DeepseekV3Attention(nn.Module):
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, self.num_key_value_heads, q_len, self.q_head_dim)  # [GQA] num_key_value_heads here
         key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
         key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
 
+        value_states = value_states  # [GQA] num_key_value_heads is already in value_states
+
+        # GQA: Repeat K/V states if num_key_value_heads < num_attention_heads
+        if self.num_key_value_groups != 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # for RoPE
             key_states, value_states = past_key_value.update(
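The `repeat_kv` calls added above rely on the helper named in the hunk header (`repeat_kv(hidden_states, n_rep)`). Its body is not shown in this diff; the sketch below is the usual expand-and-reshape implementation of that utility, assuming the `[batch, num_kv_heads, seq, head_dim]` layout used here:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each K/V head n_rep times so GQA K/V heads match the number of query heads."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

# 2 K/V heads repeated 4 times -> 8 heads, matching 8 query heads.
kv = torch.randn(1, 2, 5, 16)
print(repeat_kv(kv, 4).shape)  # torch.Size([1, 8, 5, 16])
```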
@@ -866,7 +905,7 @@ class DeepseekV3Attention(nn.Module):
 class DeepseekV3FlashAttention2(DeepseekV3Attention):
     """
     DeepseekV3 flash attention module. Inherits the same Q/K/V projections from DeepseekV3Attention.
-    Only the forward pass changes to use flash_attn APIs.
+    Only the forward pass changes to use flash_attn APIs. Supports GQA implicitly through `DeepseekV3Attention`.
     """
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -906,7 +945,7 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
         kv = (
             self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, q_len, self.num_key_value_heads, self.qk_nope_head_dim + self.v_head_dim)  # [GQA] num_key_value_heads here
             .transpose(1, 2)
         )
         k_nope, value_states = torch.split(
@@ -923,10 +962,17 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, self.num_key_value_heads, q_len, self.q_head_dim)  # [GQA] num_key_value_heads here
         key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
         key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
 
+        value_states = value_states  # [GQA] value_states already has num_key_value_heads
+
+        # GQA: Repeat K/V states if num_key_value_heads < num_attention_heads
+        if self.num_key_value_groups != 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         if self.q_head_dim != self.v_head_dim:
             # Pad if needed
             value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
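The `F.pad(value_states, [0, ...])` context line above pads only the last dimension on the right, so V matches the Q/K head dim that flash-attn expects. A tiny check of that behaviour with hypothetical dims:

```python
import torch
import torch.nn.functional as F

q_head_dim, v_head_dim = 192, 128            # hypothetical dims where V is narrower than Q/K
value_states = torch.randn(1, 8, 5, v_head_dim)

padded = F.pad(value_states, [0, q_head_dim - v_head_dim])  # pad last dim on the right only
print(padded.shape)                          # torch.Size([1, 8, 5, 192])
assert torch.equal(padded[..., :v_head_dim], value_states)  # original values preserved
```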
@@ -1091,6 +1137,7 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
 ATTENTION_CLASSES = {
     "eager": DeepseekV3Attention,
     "flash_attention_2": DeepseekV3FlashAttention2,
+    # "sparse_attention": SparseDeepseekV3Attention,  # [TODO] Placeholder for Sparse Attention class - implement SparseDeepseekV3Attention
 }
 
 
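The `ATTENTION_CLASSES` mapping above is keyed by `config._attn_implementation` (see the decoder-layer hunk below). A minimal sketch of that dispatch pattern with stand-in classes, including where a hypothetical `SparseDeepseekV3Attention` would slot in:

```python
class EagerAttention:      # stand-in for DeepseekV3Attention
    pass

class FlashAttention2:     # stand-in for DeepseekV3FlashAttention2
    pass

ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
    # "sparse_attention": SparseDeepseekV3Attention,  # hypothetical, not implemented in this commit
}

attn_implementation = "flash_attention_2"   # would come from config._attn_implementation
attn_cls = ATTENTION_CLASSES[attn_implementation]
print(attn_cls.__name__)  # FlashAttention2
```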
@@ -1106,7 +1153,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](  # [CONFIG] _attn_implementation to select attention type
            config=config, layer_idx=layer_idx
         )
 
@@ -1138,7 +1185,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         **kwargs
     ):
         """
-        Forward pass for one Deepseek decoder layer.
+        Forward pass for one Deepseek decoder layer.
         """
         residual = hidden_states
 
@@ -1443,6 +1490,10 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         Args:
             labels (torch.LongTensor of shape (batch_size, sequence_length), optional):
                 For computing the language modeling loss. Indices in [0, config.vocab_size] or -100.
+
+        **Reasoning Enhancement Considerations:**
+        * **Chain-of-Thought (CoT) Data:** To effectively improve reasoning, fine-tune this model on datasets that include Chain-of-Thought examples. The model architecture is capable of leveraging CoT if trained appropriately.
+        * **Prompt Engineering for CoT Inference:** During inference, use prompts that encourage the model to generate reasoning steps (e.g., "Let's think step by step...") to elicit Chain-of-Thought behavior.
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (output_hidden_states if output_hidden_states is not None
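The CoT-data note above hinges on training examples that spell out intermediate steps before the answer. A hypothetical record illustrating that shape (the field names are illustrative, not a required schema):

```python
# Hypothetical Chain-of-Thought fine-tuning record; the exact schema depends on your pipeline.
cot_example = {
    "prompt": "Q: A train travels 60 km in 1.5 hours. What is its average speed?\nA: Let's think step by step.",
    "target": (
        "The train covers 60 km in 1.5 hours. "
        "Average speed = distance / time = 60 / 1.5 = 40. "
        "So the average speed is 40 km/h."
    ),
}
print(cot_example["prompt"])
```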
@@ -1500,6 +1551,12 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
     ):
         """
         Prepare inputs during generation loops.
+
+        **Chain-of-Thought (CoT) Inference Hint:**
+        When using Chain-of-Thought prompting during generation, ensure your prompts are correctly formatted to encourage reasoning.
+        Consider using techniques like:
+        * "Let's think step by step:" prefix in your prompt.
+        * Sampling strategies that encourage diverse outputs for self-consistency decoding.
         """
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
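The inference hint above (step-by-step prefix plus sampling for self-consistency) can be exercised through the standard transformers generation API. A hedged sketch, assuming the checkpoint loads with `trust_remote_code=True`; the model id below is a placeholder:

```python
# Hypothetical usage; "your-org/your-deepseek-v3-checkpoint" is a placeholder model id.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-deepseek-v3-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

prompt = "Q: If a book costs 12 dollars and you buy 3, how much do you pay?\nA: Let's think step by step."
inputs = tokenizer(prompt, return_tensors="pt")

# Sampling encourages diverse reasoning paths; generate several and pick the majority
# final answer (self-consistency decoding).
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    num_return_sequences=4,
)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```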
@@ -1672,4 +1729,4 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )