tmm1 committed
Commit 5d0b27e · 1 Parent(s): 8cace80

split sdp attn into its own patch

src/axolotl/monkeypatch/llama_attn_hijack_sdp.py ADDED
@@ -0,0 +1,140 @@
+ """
+ Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
+ """
+
+ import warnings
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ import transformers.models.llama.modeling_llama
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+
+ def hijack_llama_sdp_attention():
+     transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+         sdp_attention_forward
+     )
+
+
+ def sdp_attention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_value: Optional[Tuple[torch.Tensor]] = None,
+     output_attentions: bool = False,
+     use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     # pylint: disable=duplicate-code
+     bsz, q_len, _ = hidden_states.size()
+
+     if not hasattr(self, "pretraining_tp"):
+         self.pretraining_tp = 1
+
+     if self.pretraining_tp > 1:
+         key_value_slicing = (
+             self.num_key_value_heads * self.head_dim
+         ) // self.pretraining_tp
+         query_slices = self.q_proj.weight.split(
+             (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+         )
+         key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+         value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+         query_states = [
+             F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+         ]
+         query_states = torch.cat(query_states, dim=-1)
+
+         key_states = [
+             F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+         ]
+         key_states = torch.cat(key_states, dim=-1)
+
+         value_states = [
+             F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+         ]
+         value_states = torch.cat(value_states, dim=-1)
+
+     else:
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+     query_states = query_states.view(
+         bsz, q_len, self.num_heads, self.head_dim
+     ).transpose(1, 2)
+     key_states = key_states.view(
+         bsz, q_len, self.num_key_value_heads, self.head_dim
+     ).transpose(1, 2)
+     value_states = value_states.view(
+         bsz, q_len, self.num_key_value_heads, self.head_dim
+     ).transpose(1, 2)
+     # [bsz, q_len, nh, hd]
+     # [bsz, nh, q_len, hd]
+
+     kv_seq_len = key_states.shape[-2]
+     if past_key_value is not None:
+         kv_seq_len += past_key_value[0].shape[-2]
+
+     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+     query_states, key_states = apply_rotary_pos_emb(
+         query_states, key_states, cos, sin, position_ids
+     )
+     # [bsz, nh, t, hd]
+
+     if past_key_value is not None:
+         # reuse k, v, self_attention
+         key_states = torch.cat([past_key_value[0], key_states], dim=2)
+         value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+     past_key_value = (key_states, value_states) if use_cache else None
+
+     # repeat k/v heads if n_kv_heads < n_heads
+     key_states = repeat_kv(key_states, self.num_key_value_groups)
+     value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+     if output_attentions:
+         warnings.warn(
+             "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+         )
+
+     #
+     # sdp-attn start
+     #
+
+     with torch.backends.cuda.sdp_kernel():
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=attention_mask,
+             is_causal=False,
+         )
+
+     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+         raise ValueError(
+             f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+             f" {attn_output.size()}"
+         )
+     attn_output = attn_output.transpose(1, 2)
+     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+     #
+     # sdp-attn end
+     #
+
+     if self.pretraining_tp > 1:
+         attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+         o_proj_slices = self.o_proj.weight.split(
+             self.hidden_size // self.pretraining_tp, dim=1
+         )
+         attn_output = sum(
+             F.linear(attn_output[i], o_proj_slices[i])
+             for i in range(self.pretraining_tp)
+         )
+     else:
+         attn_output = self.o_proj(attn_output)
+
+     return attn_output, None, past_key_value
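
For reference, a minimal usage sketch (not part of this commit) showing how the new module can be applied and sanity-checked on its own. The import path and function names come from the file added above; the identity assertion is purely illustrative.

# Minimal sketch, assuming the file above is importable as
# axolotl.monkeypatch.llama_attn_hijack_sdp.
import transformers.models.llama.modeling_llama as llama_modeling

from axolotl.monkeypatch.llama_attn_hijack_sdp import (
    hijack_llama_sdp_attention,
    sdp_attention_forward,
)

# Apply the monkeypatch: rebinds LlamaAttention.forward to the SDP implementation.
hijack_llama_sdp_attention()

# Because the hijack is a plain attribute rebind, an identity check is enough
# to confirm the patch took effect.
assert llama_modeling.LlamaAttention.forward is sdp_attention_forward
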
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -21,12 +21,6 @@ def hijack_llama_attention():
      transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


- def hijack_llama_sdp_attention():
-     transformers.models.llama.modeling_llama.LlamaAttention.forward = (
-         sdp_attention_forward
-     )
-
-
  def xformers_forward(
      self,
      hidden_states: torch.Tensor,
@@ -183,102 +177,3 @@ def xformers_forward(
      attn_output = self.o_proj(attn_output)

      return attn_output, attn_weights, past_key_value
-
-
- def sdp_attention_forward(
-     self,
-     hidden_states: torch.Tensor,
-     attention_mask: Optional[torch.Tensor] = None,
-     position_ids: Optional[torch.LongTensor] = None,
-     past_key_value: Optional[Tuple[torch.Tensor]] = None,
-     output_attentions: bool = False,
-     use_cache: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-     # pylint: disable=duplicate-code
-     bsz, q_len, _ = hidden_states.size()
-
-     query_states = (
-         self.q_proj(hidden_states)
-         .view(bsz, q_len, self.num_heads, self.head_dim)
-         .transpose(1, 2)
-     )
-     key_states = (
-         self.k_proj(hidden_states)
-         .view(bsz, q_len, self.num_heads, self.head_dim)
-         .transpose(1, 2)
-     )
-     value_states = (
-         self.v_proj(hidden_states)
-         .view(bsz, q_len, self.num_heads, self.head_dim)
-         .transpose(1, 2)
-     )
-
-     kv_seq_len = key_states.shape[-2]
-     if past_key_value is not None:
-         kv_seq_len += past_key_value[0].shape[-2]
-     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-     (
-         query_states,
-         key_states,
-     ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
-         query_states, key_states, cos, sin, position_ids
-     )
-     # [bsz, nh, t, hd]
-
-     if past_key_value is not None:
-         # reuse k, v, self_attention
-         key_states = torch.cat([past_key_value[0], key_states], dim=2)
-         value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-     past_key_value = (key_states, value_states) if use_cache else None
-
-     # We only apply sdp attention if we don't need to output the whole attention matrix
-     if not output_attentions:
-         with torch.backends.cuda.sdp_kernel():
-             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                 query_states,
-                 key_states,
-                 value_states,
-                 attn_mask=attention_mask,
-                 is_causal=False,
-             )
-         attn_weights = None
-     else:
-         attn_weights = torch.matmul(
-             query_states, key_states.transpose(2, 3)
-         ) / math.sqrt(self.head_dim)
-
-         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
-
-         if attention_mask is not None:
-             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                 )
-             attn_weights = attn_weights + attention_mask
-             attn_weights = torch.max(
-                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
-             )
-
-         # upcast attention to fp32
-         attn_weights = nn.functional.softmax(
-             attn_weights, dim=-1, dtype=torch.float32
-         ).to(query_states.dtype)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-         raise ValueError(
-             f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-             f" {attn_output.size()}"
-         )
-
-     attn_output = attn_output.transpose(1, 2)
-     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-     attn_output = self.o_proj(attn_output)
-
-     return attn_output, attn_weights, past_key_value
src/axolotl/utils/models.py CHANGED
@@ -112,9 +112,7 @@ def load_model(
          LOG.info("patching with xformers attention")
          hijack_llama_attention()
      elif cfg.is_llama_derived_model and cfg.sdp_attention:
-         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-             hijack_llama_sdp_attention,
-         )
+         from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention

          LOG.info("patching with sdp attention")
          hijack_llama_sdp_attention()
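
For context, a small illustrative sketch of the dispatch this hunk changes: when cfg.sdp_attention is set for a llama-derived model, the patch is now imported from the dedicated llama_attn_hijack_sdp module rather than from the xformers hijack. The SimpleNamespace below is a hypothetical stand-in for axolotl's real cfg object, not its actual load_model code.

# Illustrative only: SimpleNamespace stands in for axolotl's real cfg object.
from types import SimpleNamespace

cfg = SimpleNamespace(is_llama_derived_model=True, sdp_attention=True)

if cfg.is_llama_derived_model and cfg.sdp_attention:
    # The import now comes from the dedicated module added in this commit,
    # instead of llama_attn_hijack_xformers.
    from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention

    hijack_llama_sdp_attention()
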