Implement MLA inference optimizations to DeepseekV2Attention
#12
by sy-chen · opened
- modeling_deepseek.py +18 -28
modeling_deepseek.py CHANGED
@@ -822,17 +822,10 @@ class DeepseekV2Attention(nn.Module):
         compressed_kv, k_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
-        kv = (
-            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            .transpose(1, 2)
-        )
 
-        k_nope, value_states = torch.split(
-            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
-        )
-        kv_seq_len = value_states.shape[-2]
+        kv_seq_len = k_pe.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
@@ -841,27 +834,22 @@ class DeepseekV2Attention(nn.Module):
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
+        cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
-        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
-
-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        attn_weights = (
-            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
-        )
-
+            compressed_kv = compressed_kv.unsqueeze(1)
+            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+            compressed_kv = compressed_kv.squeeze(1)
+
+        kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
+        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]
+        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
+
+        q_nope = torch.matmul(q_nope, q_absorb)
+        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@@ -878,11 +866,13 @@ class DeepseekV2Attention(nn.Module):
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(
             attn_weights, dim=-1, dtype=torch.float32
-        ).to(query_states.dtype)
+        ).to(q_pe.dtype)
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.attention_dropout, training=self.training
         )
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
+
+        attn_output = torch.matmul(attn_output, out_absorb.mT)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
             raise ValueError(
@@ -1902,4 +1892,4 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
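
For readers new to the trick: instead of materializing per-head key/value states from the latent and caching them, the new path caches the low-rank latent (compressed_kv, after kv_a_layernorm) together with the shared RoPE key k_pe, and folds the two halves of the kv_b_proj weight into the query side (q_absorb) and the output side (out_absorb), so attention runs directly against the latent. The sketch below is a minimal single-head illustration of that identity with toy dimensions; the names (T, d_nope, d_v, d_c, W_uk, W_uv) are illustrative and not from the file, and the decoupled RoPE branch, softmax scaling, and causal masking are omitted.

# Minimal single-head sketch of the weight-absorption identity (toy dimensions).
import torch

torch.manual_seed(0)
T, d_nope, d_v, d_c = 4, 8, 6, 16     # toy: seq len, nope head dim, value head dim, latent rank
c_kv = torch.randn(T, d_c)            # stands in for kv_a_layernorm(compressed_kv) of one sequence
q    = torch.randn(T, d_nope)         # stands in for q_nope of one head
W_uk = torch.randn(d_nope, d_c)       # rows of kv_b_proj that produce k_nope (the "q_absorb" slice)
W_uv = torch.randn(d_v, d_c)          # rows of kv_b_proj that produce v      (the "out_absorb" slice)

# Naive path: expand per-head keys/values from the latent, then attend.
k_nope = c_kv @ W_uk.T                                # (T, d_nope)
v      = c_kv @ W_uv.T                                # (T, d_v)
probs  = torch.softmax(q @ k_nope.T, dim=-1)          # (T, T)
out_naive = probs @ v                                 # (T, d_v)

# Absorbed path: fold W_uk into the query and W_uv into the output,
# so attention is computed against the compact latent c_kv itself.
q_abs  = q @ W_uk                                     # (T, d_c)
probs  = torch.softmax(q_abs @ c_kv.T, dim=-1)        # identical scores
out_abs = (probs @ c_kv) @ W_uv.T                     # (T, d_v)

print(torch.allclose(out_naive, out_abs, atol=1e-4))  # True

With this layout the cache holds kv_lora_rank + qk_rope_head_dim values per token per layer instead of num_heads * (q_head_dim + v_head_dim), which is where the inference-time memory saving comes from.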