sync xformers patch to follow shared format and be diffable
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
CHANGED
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
 """
 
 import logging
-import math
+import warnings
 from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
-from torch import nn
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
 try:
     import xformers.ops
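Note on the new imports: `warnings` backs the `output_attentions` warning added further down, and `apply_rotary_pos_emb` / `repeat_kv` are now pulled in directly instead of being referenced through the fully qualified `transformers.models.llama.modeling_llama` path. As a rough sketch of what `repeat_kv` does for grouped-query attention (illustration only; the transformers implementation uses an expand/reshape but is equivalent to a `repeat_interleave` along the head dimension):

    import torch

    def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)
        if n_rep == 1:
            return hidden_states
        return hidden_states.repeat_interleave(n_rep, dim=1)

    key_states = torch.randn(2, 8, 16, 128)        # 8 kv heads
    key_states = repeat_kv_sketch(key_states, 4)   # broadcast up to 32 query heads
    assert key_states.shape == (2, 32, 16, 128)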
@@ -75,15 +75,15 @@ def xformers_forward(
     value_states = value_states.view(
         bsz, q_len, self.num_key_value_heads, self.head_dim
     ).transpose(1, 2)
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
+
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    (
-        query_states,
-        key_states,
-    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+    query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
     # [bsz, nh, t, hd]
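The hunk above swaps the awkward tuple-unpacked call to `transformers.models.llama.modeling_llama.apply_rotary_pos_emb` for the directly imported name; behavior is unchanged. For reference, a simplified sketch of the rotation that helper applies (illustration only; the library version also selects `cos`/`sin` rows by `position_ids` and handles broadcasting across heads):

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_sketch(q, k, cos, sin):
        # q, k: (bsz, num_heads, seq_len, head_dim); cos/sin broadcastable to the same shape
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed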
@@ -96,74 +96,50 @@ def xformers_forward(
     past_key_value = (key_states, value_states) if use_cache else None
 
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = transformers.models.llama.modeling_llama.repeat_kv(
-        key_states, self.num_key_value_groups
-    )
-    value_states = transformers.models.llama.modeling_llama.repeat_kv(
-        value_states, self.num_key_value_groups
-    )
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if output_attentions:
+        warnings.warn(
+            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+        )
+
+    #
+    # xformers-attn start
+    #
 
-    # We only apply xformers optimizations if we don't need to output the whole attention matrix
-    if not output_attentions:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
-        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
-        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-            # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(
-                query_states, key_states, value_states, attn_bias=None
-            )
-        else:
-            # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(
-                query_states,
-                key_states,
-                value_states,
-                # attn_bias=attention_mask,
-                attn_bias=xformers.ops.LowerTriangularMask(),
-            )
-        attn_weights = None
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+    # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+        # input and output should be of form (bsz, q_len, num_heads, head_dim)
+        attn_output = xformers.ops.memory_efficient_attention(
+            query_states, key_states, value_states, attn_bias=None
+        )
     else:
-        attn_weights = torch.matmul(
-            query_states, key_states.transpose(2, 3)
-        ) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(
-                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
-            )
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(
-            attn_weights, dim=-1, dtype=torch.float32
-        ).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        # end x-formers vs. not x-formers if-else block
+        # input and output should be of form (bsz, q_len, num_heads, head_dim)
+        attn_output = xformers.ops.memory_efficient_attention(
+            query_states,
+            key_states,
+            value_states,
+            # attn_bias=attention_mask,
+            attn_bias=xformers.ops.LowerTriangularMask(),
+        )
 
+    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
+    #
+    # xformers-attn end
+    #
+
     if self.pretraining_tp > 1:
         attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
         o_proj_slices = self.o_proj.weight.split(
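The hunk above is the heart of the change: the old `if not output_attentions:` wrapper and the eager softmax-attention fallback are dropped, the tensors are transposed into the (bsz, q_len, num_heads, head_dim) layout xformers expects, and the causal case is handled with `LowerTriangularMask()`. A minimal usage sketch of that call pattern (shapes invented for illustration; assumes a CUDA device and a dtype xformers supports):

    import torch
    import xformers.ops

    bsz, q_len, n_heads, head_dim = 2, 16, 32, 128
    q = torch.randn(bsz, q_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # mask-free case (attn_bias=None), e.g. when the additive mask is all zeros
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    # causal case: a lower-triangular bias instead of a dense additive mask
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=xformers.ops.LowerTriangularMask()
    )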
@@ -176,4 +152,4 @@ def xformers_forward(
     else:
         attn_output = self.o_proj(attn_output)
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output, None, past_key_value
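With this change the forward always returns `None` for the attention weights (hence the warning when `output_attentions` is requested) alongside `attn_output` and `past_key_value`, matching the 3-tuple `LlamaAttention.forward` is expected to produce. For context, a monkeypatch like this is typically installed by reassigning the attention class's `forward`; a minimal sketch, assuming a hypothetical `hijack_llama_attention` helper (check the module for the actual entry point axolotl exposes):

    import transformers.models.llama.modeling_llama

    def hijack_llama_attention():
        # replace the stock attention forward with the xformers-backed version above
        transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward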