tmm1 committed
Commit 985dcbc · 1 Parent(s): 5d0b27e

sync xformers patch to follow shared format and be diffable

src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
 """
 
 import logging
-import math
+import warnings
 from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
-from torch import nn
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
 try:
     import xformers.ops
@@ -75,15 +75,15 @@ def xformers_forward(
     value_states = value_states.view(
         bsz, q_len, self.num_key_value_heads, self.head_dim
     ).transpose(1, 2)
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
+
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    (
-        query_states,
-        key_states,
-    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+    query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
     # [bsz, nh, t, hd]
@@ -96,74 +96,50 @@ def xformers_forward(
     past_key_value = (key_states, value_states) if use_cache else None
 
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = transformers.models.llama.modeling_llama.repeat_kv(
-        key_states, self.num_key_value_groups
-    )
-    value_states = transformers.models.llama.modeling_llama.repeat_kv(
-        value_states, self.num_key_value_groups
-    )
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if output_attentions:
+        warnings.warn(
+            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+        )
+
+    #
+    # xformers-attn start
+    #
 
-    # We only apply xformers optimizations if we don't need to output the whole attention matrix
-    if not output_attentions:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
-        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
-        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-            # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(
-                query_states, key_states, value_states, attn_bias=None
-            )
-        else:
-            # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(
-                query_states,
-                key_states,
-                value_states,
-                # attn_bias=attention_mask,
-                attn_bias=xformers.ops.LowerTriangularMask(),
-            )
-        attn_weights = None
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+    # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+        # input and output should be of form (bsz, q_len, num_heads, head_dim)
+        attn_output = xformers.ops.memory_efficient_attention(
+            query_states, key_states, value_states, attn_bias=None
+        )
     else:
-        attn_weights = torch.matmul(
-            query_states, key_states.transpose(2, 3)
-        ) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(
-                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
-            )
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(
-            attn_weights, dim=-1, dtype=torch.float32
-        ).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        # end x-formers vs. not x-formers if-else block
+        # input and output should be of form (bsz, q_len, num_heads, head_dim)
+        attn_output = xformers.ops.memory_efficient_attention(
+            query_states,
+            key_states,
+            value_states,
+            # attn_bias=attention_mask,
+            attn_bias=xformers.ops.LowerTriangularMask(),
+        )
 
+    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
+    #
+    # xformers-attn end
+    #
+
    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
@@ -176,4 +152,4 @@ def xformers_forward(
     else:
         attn_output = self.o_proj(attn_output)
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output, None, past_key_value
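For orientation, a minimal sketch of how a patched forward like the one above is typically installed. This is not part of the commit; the helper name `hijack_llama_attention` and the call site are assumptions about the usual axolotl-style monkeypatch pattern, which reassigns `LlamaAttention.forward` on the transformers module so every Llama attention layer routes through the xformers path.

# Minimal sketch, not part of this commit: wiring the patched forward into
# transformers by monkeypatching the attention class. The helper name
# `hijack_llama_attention` is assumed here for illustration.
import transformers.models.llama.modeling_llama as modeling_llama

from axolotl.monkeypatch.llama_attn_hijack_xformers import xformers_forward


def hijack_llama_attention() -> None:
    # Replace the stock forward so LlamaAttention layers call the
    # xformers.ops.memory_efficient_attention-backed implementation.
    modeling_llama.LlamaAttention.forward = xformers_forward


hijack_llama_attention()  # apply the patch once at startup

Because the patched function keeps the stock signature and return contract (`attn_output, None, past_key_value`, with attention weights dropped), a class-level reassignment like this can stand in for the original forward without further changes to the caller.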