abhi-mosaic commited on
Commit
9e929f5
·
1 Parent(s): 053e1a3
Files changed (4) hide show
  1. attention.py +48 -33
  2. blocks.py +4 -4
  3. configuration_mpt.py +1 -1
  4. modeling_mpt.py +23 -7
attention.py CHANGED
@@ -17,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
17
  return False
18
  return original_is_causal
19
 
20
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
23
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
24
- min_val = torch.finfo(q.dtype).min
 
 
 
 
 
25
  (b, _, s_q, d) = q.shape
26
  s_k = k.size(-1)
27
  if softmax_scale is None:
28
  softmax_scale = 1 / math.sqrt(d)
29
  attn_weight = q.matmul(k) * softmax_scale
30
  if attn_bias is not None:
 
 
 
31
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
32
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
33
  attn_weight = attn_weight + attn_bias
 
34
  if key_padding_mask is not None:
35
  if attn_bias is not None:
36
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
37
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
38
- if is_causal:
39
  s = max(s_q, s_k)
40
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
41
  causal_mask = causal_mask.tril()
@@ -49,8 +58,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
49
  out = attn_weight.matmul(v)
50
  out = rearrange(out, 'b h s d -> b s (h d)')
51
  if needs_weights:
52
- return (out, attn_weight)
53
- return (out, None)
54
 
55
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
56
  for tensor in tensors:
@@ -59,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
59
  if not tensor.is_cuda:
60
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
61
 
62
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
63
  try:
64
  from flash_attn import bert_padding, flash_attn_interface
65
  except:
66
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
67
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
68
  if attn_bias is not None:
69
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
70
  (batch_size, seqlen) = query.shape[:2]
@@ -84,9 +102,9 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
84
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
85
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
86
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
87
- return (output, None)
88
 
89
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
90
  try:
91
  from .flash_attn_triton import flash_attn_func
92
  except:
@@ -100,6 +118,15 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
100
  if not _installed:
101
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
102
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
103
  if dropout_p:
104
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
105
  if needs_weights:
@@ -119,7 +146,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
119
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
120
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
121
  output = attn_output.view(*attn_output.shape[:2], -1)
122
- return (output, None)
123
 
124
  class MultiheadAttention(nn.Module):
125
  """Multi-head self attention.
@@ -128,7 +155,7 @@ class MultiheadAttention(nn.Module):
128
  additive bias.
129
  """
130
 
131
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
132
  super().__init__()
133
  self.attn_impl = attn_impl
134
  self.clip_qkv = clip_qkv
@@ -150,10 +177,11 @@ class MultiheadAttention(nn.Module):
150
  self.attn_fn = flash_attn_fn
151
  elif self.attn_impl == 'triton':
152
  self.attn_fn = triton_flash_attn_fn
153
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
154
  elif self.attn_impl == 'torch':
155
  self.attn_fn = scaled_multihead_dot_product_attention
156
- if torch.cuda.is_available():
157
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
158
  else:
159
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -170,14 +198,7 @@ class MultiheadAttention(nn.Module):
170
  dtype = query.dtype
171
  query = self.q_ln(query).to(dtype)
172
  key = self.k_ln(key).to(dtype)
173
- if past_key_value is not None:
174
- if len(past_key_value) != 0:
175
- key = torch.cat([past_key_value[0], key], dim=1)
176
- value = torch.cat([past_key_value[1], value], dim=1)
177
- past_key_value = (key, value)
178
- if attn_bias is not None:
179
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
180
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
181
  return (self.out_proj(context), attn_weights, past_key_value)
182
 
183
  class MultiQueryAttention(nn.Module):
@@ -187,7 +208,7 @@ class MultiQueryAttention(nn.Module):
187
  additive bias.
188
  """
189
 
190
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
191
  super().__init__()
192
  self.attn_impl = attn_impl
193
  self.clip_qkv = clip_qkv
@@ -210,10 +231,11 @@ class MultiQueryAttention(nn.Module):
210
  self.attn_fn = flash_attn_fn
211
  elif self.attn_impl == 'triton':
212
  self.attn_fn = triton_flash_attn_fn
213
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
214
  elif self.attn_impl == 'torch':
215
  self.attn_fn = scaled_multihead_dot_product_attention
216
- if torch.cuda.is_available():
217
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
218
  else:
219
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -230,14 +252,7 @@ class MultiQueryAttention(nn.Module):
230
  dtype = query.dtype
231
  query = self.q_ln(query).to(dtype)
232
  key = self.k_ln(key).to(dtype)
233
- if past_key_value is not None:
234
- if len(past_key_value) != 0:
235
- key = torch.cat([past_key_value[0], key], dim=1)
236
- value = torch.cat([past_key_value[1], value], dim=1)
237
- past_key_value = (key, value)
238
- if attn_bias is not None:
239
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
240
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
241
  return (self.out_proj(context), attn_weights, past_key_value)
242
 
243
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
 
17
  return False
18
  return original_is_causal
19
 
20
+ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
+ kv_n_heads = 1 if multiquery else n_heads
23
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
24
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
25
+ if past_key_value is not None:
26
+ if len(past_key_value) != 0:
27
+ k = torch.cat([past_key_value[0], k], dim=3)
28
+ v = torch.cat([past_key_value[1], v], dim=2)
29
+ past_key_value = (k, v)
30
  (b, _, s_q, d) = q.shape
31
  s_k = k.size(-1)
32
  if softmax_scale is None:
33
  softmax_scale = 1 / math.sqrt(d)
34
  attn_weight = q.matmul(k) * softmax_scale
35
  if attn_bias is not None:
36
+ _s_q = max(0, attn_bias.size(2) - s_q)
37
+ _s_k = max(0, attn_bias.size(3) - s_k)
38
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
39
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
40
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
41
  attn_weight = attn_weight + attn_bias
42
+ min_val = torch.finfo(q.dtype).min
43
  if key_padding_mask is not None:
44
  if attn_bias is not None:
45
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
46
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
47
+ if is_causal and (not q.size(2) == 1):
48
  s = max(s_q, s_k)
49
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
50
  causal_mask = causal_mask.tril()
 
58
  out = attn_weight.matmul(v)
59
  out = rearrange(out, 'b h s d -> b s (h d)')
60
  if needs_weights:
61
+ return (out, attn_weight, past_key_value)
62
+ return (out, None, past_key_value)
63
 
64
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
65
  for tensor in tensors:
 
68
  if not tensor.is_cuda:
69
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
70
 
71
+ def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
72
  try:
73
  from flash_attn import bert_padding, flash_attn_interface
74
  except:
75
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
76
  check_valid_inputs(query, key, value)
77
+ if past_key_value is not None:
78
+ if len(past_key_value) != 0:
79
+ key = torch.cat([past_key_value[0], key], dim=1)
80
+ value = torch.cat([past_key_value[1], value], dim=1)
81
+ past_key_value = (key, value)
82
+ if attn_bias is not None:
83
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
84
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
85
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
86
  if attn_bias is not None:
87
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
88
  (batch_size, seqlen) = query.shape[:2]
 
102
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
103
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
104
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
105
+ return (output, None, past_key_value)
106
 
107
+ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
108
  try:
109
  from .flash_attn_triton import flash_attn_func
110
  except:
 
118
  if not _installed:
119
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
120
  check_valid_inputs(query, key, value)
121
+ if past_key_value is not None:
122
+ if len(past_key_value) != 0:
123
+ key = torch.cat([past_key_value[0], key], dim=1)
124
+ value = torch.cat([past_key_value[1], value], dim=1)
125
+ past_key_value = (key, value)
126
+ if attn_bias is not None:
127
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
128
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
129
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
130
  if dropout_p:
131
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
132
  if needs_weights:
 
146
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
147
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
148
  output = attn_output.view(*attn_output.shape[:2], -1)
149
+ return (output, None, past_key_value)
150
 
151
  class MultiheadAttention(nn.Module):
152
  """Multi-head self attention.
 
155
  additive bias.
156
  """
157
 
158
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
159
  super().__init__()
160
  self.attn_impl = attn_impl
161
  self.clip_qkv = clip_qkv
 
177
  self.attn_fn = flash_attn_fn
178
  elif self.attn_impl == 'triton':
179
  self.attn_fn = triton_flash_attn_fn
180
+ if verbose:
181
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
182
  elif self.attn_impl == 'torch':
183
  self.attn_fn = scaled_multihead_dot_product_attention
184
+ if torch.cuda.is_available() and verbose:
185
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
186
  else:
187
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
198
  dtype = query.dtype
199
  query = self.q_ln(query).to(dtype)
200
  key = self.k_ln(key).to(dtype)
201
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
 
 
 
 
 
 
 
202
  return (self.out_proj(context), attn_weights, past_key_value)
203
 
204
  class MultiQueryAttention(nn.Module):
 
208
  additive bias.
209
  """
210
 
211
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
212
  super().__init__()
213
  self.attn_impl = attn_impl
214
  self.clip_qkv = clip_qkv
 
231
  self.attn_fn = flash_attn_fn
232
  elif self.attn_impl == 'triton':
233
  self.attn_fn = triton_flash_attn_fn
234
+ if verbose:
235
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
236
  elif self.attn_impl == 'torch':
237
  self.attn_fn = scaled_multihead_dot_product_attention
238
+ if torch.cuda.is_available() and verbose:
239
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
240
  else:
241
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
252
  dtype = query.dtype
253
  query = self.q_ln(query).to(dtype)
254
  key = self.k_ln(key).to(dtype)
255
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
 
 
 
 
 
 
 
256
  return (self.out_proj(context), attn_weights, past_key_value)
257
 
258
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
blocks.py CHANGED
@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):
19
 
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
 
19
 
20
  class MPTBlock(nn.Module):
21
 
22
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
+ self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
 
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
+ return (x, attn_weights, past_key_value)
configuration_mpt.py CHANGED
@@ -2,7 +2,7 @@
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
 
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
+ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
modeling_mpt.py CHANGED
@@ -18,12 +18,16 @@ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
 
 
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
- _no_split_modules=["MPTBlock"]
27
 
28
  class MPTModel(MPTPreTrainedModel):
29
 
@@ -47,6 +51,7 @@ class MPTModel(MPTPreTrainedModel):
47
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
48
  self.norm_f = norm_class(config.d_model, device=config.init_device)
49
  if config.init_device != 'meta':
 
50
  self.apply(self.param_init_fn)
51
  self.is_causal = not self.prefix_lm
52
  self._attn_bias_initialized = False
@@ -96,7 +101,8 @@ class MPTModel(MPTPreTrainedModel):
96
  if attn_bias is None:
97
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
98
  else:
99
- attn_bias = attn_bias[:, :, :, -s_k:]
 
100
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
101
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
102
  min_val = torch.finfo(attn_bias.dtype).min
@@ -138,7 +144,8 @@ class MPTModel(MPTPreTrainedModel):
138
  if not return_dict:
139
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
140
  if output_attentions:
141
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
 
142
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
143
  raise NotImplementedError('MPT does not support training with left padding.')
144
  if self.prefix_lm and prefix_mask is None:
@@ -159,6 +166,8 @@ class MPTModel(MPTPreTrainedModel):
159
  if len(past_key_values) != self.config.n_layers:
160
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
161
  past_position = past_key_values[0][0].size(1)
 
 
162
  if S + past_position > self.config.max_seq_len:
163
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
164
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -176,16 +185,23 @@ class MPTModel(MPTPreTrainedModel):
176
  if use_cache and past_key_values is None:
177
  past_key_values = [() for _ in range(self.config.n_layers)]
178
  all_hidden_states = () if output_hidden_states else None
 
179
  for (b_idx, block) in enumerate(self.blocks):
180
  if output_hidden_states:
181
  assert all_hidden_states is not None
182
  all_hidden_states = all_hidden_states + (x,)
183
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
184
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
185
  if past_key_values is not None:
186
  past_key_values[b_idx] = past_key_value
 
 
 
187
  x = self.norm_f(x)
188
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
189
 
190
  def param_init_fn(self, module):
191
  init_fn_name = self.config.init_config['name']
@@ -236,7 +252,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
236
  return_dict = return_dict if return_dict is not None else self.config.return_dict
237
  use_cache = use_cache if use_cache is not None else self.config.use_cache
238
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
239
- logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
240
  if self.logit_scale is not None:
241
  if self.logit_scale == 0:
242
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
@@ -246,7 +262,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
246
  labels = torch.roll(labels, shifts=-1)
247
  labels[:, -1] = -100
248
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
249
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
250
 
251
  def param_init_fn(self, module):
252
  init_fn_name = self.config.init_config['name']
 
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
21
+ try:
22
+ from .flash_attn_triton import flash_attn_func
23
+ except:
24
+ pass
25
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
26
 
27
  class MPTPreTrainedModel(PreTrainedModel):
28
  config_class = MPTConfig
29
  base_model_prefix = 'model'
30
+ _no_split_modules = ['MPTBlock']
31
 
32
  class MPTModel(MPTPreTrainedModel):
33
 
 
51
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
52
  self.norm_f = norm_class(config.d_model, device=config.init_device)
53
  if config.init_device != 'meta':
54
+ print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
55
  self.apply(self.param_init_fn)
56
  self.is_causal = not self.prefix_lm
57
  self._attn_bias_initialized = False
 
101
  if attn_bias is None:
102
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
103
  else:
104
+ _s_k = max(0, attn_bias.size(-1) - s_k)
105
+ attn_bias = attn_bias[:, :, :, _s_k:]
106
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
107
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
108
  min_val = torch.finfo(attn_bias.dtype).min
 
144
  if not return_dict:
145
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
146
  if output_attentions:
147
+ if self.attn_impl != 'torch':
148
+ raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
149
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
150
  raise NotImplementedError('MPT does not support training with left padding.')
151
  if self.prefix_lm and prefix_mask is None:
 
166
  if len(past_key_values) != self.config.n_layers:
167
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
168
  past_position = past_key_values[0][0].size(1)
169
+ if self.attn_impl == 'torch':
170
+ past_position = past_key_values[0][0].size(3)
171
  if S + past_position > self.config.max_seq_len:
172
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
173
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
185
  if use_cache and past_key_values is None:
186
  past_key_values = [() for _ in range(self.config.n_layers)]
187
  all_hidden_states = () if output_hidden_states else None
188
+ all_self_attns = () if output_attentions else None
189
  for (b_idx, block) in enumerate(self.blocks):
190
  if output_hidden_states:
191
  assert all_hidden_states is not None
192
  all_hidden_states = all_hidden_states + (x,)
193
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
194
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
195
  if past_key_values is not None:
196
  past_key_values[b_idx] = past_key_value
197
+ if output_attentions:
198
+ assert all_self_attns is not None
199
+ all_self_attns = all_self_attns + (attn_weights,)
200
  x = self.norm_f(x)
201
+ if output_hidden_states:
202
+ assert all_hidden_states is not None
203
+ all_hidden_states = all_hidden_states + (x,)
204
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
205
 
206
  def param_init_fn(self, module):
207
  init_fn_name = self.config.init_config['name']
 
252
  return_dict = return_dict if return_dict is not None else self.config.return_dict
253
  use_cache = use_cache if use_cache is not None else self.config.use_cache
254
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
255
+ logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
256
  if self.logit_scale is not None:
257
  if self.logit_scale == 0:
258
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
262
  labels = torch.roll(labels, shifts=-1)
263
  labels[:, -1] = -100
264
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
265
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
266
 
267
  def param_init_fn(self, module):
268
  init_fn_name = self.config.init_config['name']