Implement MLA inference optimizations to DeepseekV2Attention
Browse filesThis patched DeepseekV2Model contains the following modifications to DeepseekV2Attention for reducing VRAM consumption and improve efficiency:
1. Instead of caching the decompressed key/value states, we cache only the low-rank key-value joint compression as well as
the decoupled RoPE part of the keys. For the sake of reusing the cache utility of transformers library, we treat
k_pe as key_states and compressed_kv as value_states.
2. We implement the absorption technique described in the DeepseekV2 paper, by changing the multiplication order when
computing query and output vectors. This not only saves memory consumption of intermediate tensors but also reduces
the number of floating-point operations.
3. We compute the RoPE part and non-RoPE part of the attention score separately and then sum them up. The original
implementation concatenates the two parts of the query/key vectors, which has proven to be quite inefficient when
caching compressed key/value states due to unnecessary data broadcast and memory round-trips.
By applying the above changes, the MLA module can achieve up to 20.4x speedup for single request and 3.63x for 32
batched requests on an NVIDIA A100-PCIE-40GB GPU during the decoding phase, as well as 26.2x and 3.52x speedup on
NVIDIA GeForce RTX 4080 for single and batched requests, respectively.
More detailed description of the modification can be found in https://zhuanlan.zhihu.com/p/700214123?utm_psn=1779287628619632640 and https://github.com/madsys-dev/deepseekv2-profile/blob/924174cb5dc11fad24bdaad3fd820ebf87506368/workspace/blog/optimizing-mla.md (in Chinese).
- modeling_deepseek.py +18 -28
@@ -822,17 +822,10 @@ class DeepseekV2Attention(nn.Module):
|
|
822 |
compressed_kv, k_pe = torch.split(
|
823 |
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
824 |
)
|
|
|
825 |
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
826 |
-
kv = (
|
827 |
-
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
828 |
-
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
829 |
-
.transpose(1, 2)
|
830 |
-
)
|
831 |
|
832 |
-
|
833 |
-
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
834 |
-
)
|
835 |
-
kv_seq_len = value_states.shape[-2]
|
836 |
if past_key_value is not None:
|
837 |
if self.layer_idx is None:
|
838 |
raise ValueError(
|
@@ -841,27 +834,22 @@ class DeepseekV2Attention(nn.Module):
|
|
841 |
"with a layer index."
|
842 |
)
|
843 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
844 |
-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
845 |
|
|
|
846 |
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
847 |
|
848 |
-
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
849 |
-
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
850 |
-
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
851 |
-
|
852 |
-
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
853 |
-
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
854 |
-
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
|
855 |
if past_key_value is not None:
|
856 |
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
857 |
-
|
858 |
-
|
859 |
-
)
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
|
|
|
|
|
865 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
866 |
raise ValueError(
|
867 |
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
@@ -878,11 +866,13 @@ class DeepseekV2Attention(nn.Module):
|
|
878 |
# upcast attention to fp32
|
879 |
attn_weights = nn.functional.softmax(
|
880 |
attn_weights, dim=-1, dtype=torch.float32
|
881 |
-
).to(
|
882 |
attn_weights = nn.functional.dropout(
|
883 |
attn_weights, p=self.attention_dropout, training=self.training
|
884 |
)
|
885 |
-
attn_output = torch.
|
|
|
|
|
886 |
|
887 |
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
|
888 |
raise ValueError(
|
@@ -1902,4 +1892,4 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
|
|
1902 |
past_key_values=transformer_outputs.past_key_values,
|
1903 |
hidden_states=transformer_outputs.hidden_states,
|
1904 |
attentions=transformer_outputs.attentions,
|
1905 |
-
)
|
|
|
822 |
compressed_kv, k_pe = torch.split(
|
823 |
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
824 |
)
|
825 |
+
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
826 |
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
827 |
|
828 |
+
kv_seq_len = k_pe.shape[-2]
|
|
|
|
|
|
|
829 |
if past_key_value is not None:
|
830 |
if self.layer_idx is None:
|
831 |
raise ValueError(
|
|
|
834 |
"with a layer index."
|
835 |
)
|
836 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
837 |
|
838 |
+
cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
|
839 |
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
840 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
841 |
if past_key_value is not None:
|
842 |
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
843 |
+
compressed_kv = compressed_kv.unsqueeze(1)
|
844 |
+
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
|
845 |
+
compressed_kv = compressed_kv.squeeze(1)
|
846 |
+
|
847 |
+
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
|
848 |
+
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]
|
849 |
+
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
|
850 |
+
|
851 |
+
q_nope = torch.matmul(q_nope, q_absorb)
|
852 |
+
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
|
853 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
854 |
raise ValueError(
|
855 |
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
|
|
866 |
# upcast attention to fp32
|
867 |
attn_weights = nn.functional.softmax(
|
868 |
attn_weights, dim=-1, dtype=torch.float32
|
869 |
+
).to(q_pe.dtype)
|
870 |
attn_weights = nn.functional.dropout(
|
871 |
attn_weights, p=self.attention_dropout, training=self.training
|
872 |
)
|
873 |
+
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
|
874 |
+
|
875 |
+
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
876 |
|
877 |
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
|
878 |
raise ValueError(
|
|
|
1892 |
past_key_values=transformer_outputs.past_key_values,
|
1893 |
hidden_states=transformer_outputs.hidden_states,
|
1894 |
attentions=transformer_outputs.attentions,
|
1895 |
+
)
|