daking commited on
Commit
1136774
·
1 Parent(s): 7dda9b2

LLM-foundry update October 30, 2023 21:16:19

Browse files
Files changed (1) hide show
  1. hf_prefixlm_converter.py +2 -242
hf_prefixlm_converter.py CHANGED
@@ -6,23 +6,13 @@ Causal LM to convert it to a Prefix LM.
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
9
- import math
10
- import warnings
11
  from types import MethodType
12
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
13
  import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
17
- from transformers.models.bloom.modeling_bloom import logging
18
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
20
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
- from transformers.models.opt.modeling_opt import OPTForCausalLM
23
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
25
- logger = logging.get_logger(__name__)
26
  _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
  CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
28
 
@@ -110,232 +100,8 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
110
  setattr(model, 'generate', MethodType(generate, model))
111
  setattr(model, '_prefix_lm_converted', True)
112
  return model
113
-
114
- def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
115
- """Converts a BLOOM Causal LM to a Prefix LM.
116
-
117
- Supported HuggingFace model classes:
118
- - `BloomForCausalLM`
119
-
120
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
121
- """
122
- if hasattr(model, '_prefix_lm_converted'):
123
- return model
124
- assert isinstance(model, BloomForCausalLM)
125
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
126
-
127
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
128
- combined_attention_mask = None
129
- device = attention_mask.device
130
- (_, src_length) = input_shape
131
- if src_length > 1:
132
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
133
- if bidirectional_mask is not None:
134
- assert attention_mask.shape == bidirectional_mask.shape
135
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
136
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
137
- expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
138
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
139
- return combined_attention_mask
140
-
141
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
142
- num_heads = self.config.n_head
143
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
144
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
145
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
146
- slopes = torch.pow(base, powers)
147
- if closest_power_of_2 != num_heads:
148
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
149
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
150
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
151
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
152
- qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
153
- ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
154
- diffs = qa - ka + key_length - query_length
155
- diffs = -diffs.abs()
156
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
157
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
158
- return alibi.to(dtype)
159
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
160
-
161
- def transformer_forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
162
- if deprecated_arguments.pop('position_ids', False) is not False:
163
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
164
- if len(deprecated_arguments) > 0:
165
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
166
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
167
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
168
- use_cache = use_cache if use_cache is not None else self.config.use_cache
169
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
170
- if input_ids is not None and inputs_embeds is not None:
171
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
172
- elif input_ids is not None:
173
- (batch_size, seq_length) = input_ids.shape
174
- elif inputs_embeds is not None:
175
- (batch_size, seq_length, _) = inputs_embeds.shape
176
- else:
177
- raise ValueError('You have to specify either input_ids or inputs_embeds')
178
- if past_key_values is None:
179
- past_key_values = tuple([None] * len(self.h))
180
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
181
- if inputs_embeds is None:
182
- inputs_embeds = self.word_embeddings(input_ids)
183
- hidden_states = self.word_embeddings_layernorm(inputs_embeds)
184
- presents = () if use_cache else None
185
- all_self_attentions = () if output_attentions else None
186
- all_hidden_states = () if output_hidden_states else None
187
- seq_length_with_past = seq_length
188
- past_key_values_length = 0
189
- if past_key_values[0] is not None:
190
- tmp = past_key_values[0][0]
191
- past_key_values_length = tmp.shape[2]
192
- seq_length_with_past = seq_length_with_past + past_key_values_length
193
- if attention_mask is None:
194
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
195
- else:
196
- attention_mask = attention_mask.to(hidden_states.device)
197
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
198
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
199
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
200
- if output_hidden_states:
201
- hst = (hidden_states,)
202
- all_hidden_states = all_hidden_states + hst
203
- if self.gradient_checkpointing and self.training:
204
- if use_cache:
205
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
206
- use_cache = False
207
-
208
- def create_custom_forward(module: torch.nn.Module):
209
-
210
- def custom_forward(*inputs: Any):
211
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
212
- return custom_forward
213
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
214
- else:
215
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
216
- hidden_states = outputs[0]
217
- if use_cache is True:
218
- presents = presents + (outputs[1],)
219
- if output_attentions:
220
- oa = (outputs[2 if use_cache else 1],)
221
- all_self_attentions = all_self_attentions + oa
222
- hidden_states = self.ln_f(hidden_states)
223
- if output_hidden_states:
224
- hst = (hidden_states,)
225
- all_hidden_states = all_hidden_states + hst
226
- if not return_dict:
227
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
228
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
229
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
230
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
231
- setattr(model.transformer, 'forward', MethodType(transformer_forward, model.transformer))
232
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
233
-
234
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
235
- """Replacement forward method for BloomCausalLM."""
236
- if deprecated_arguments.pop('position_ids', False) is not False:
237
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
238
- if len(deprecated_arguments) > 0:
239
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
240
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
241
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
242
- hidden_states = transformer_outputs[0]
243
- lm_logits = self.lm_head(hidden_states)
244
- loss = None
245
- if labels is not None:
246
- shift_logits = lm_logits[..., :-1, :].contiguous()
247
- shift_labels = labels[..., 1:].contiguous()
248
- (batch_size, seq_length, vocab_size) = shift_logits.shape
249
- loss_fct = CrossEntropyLoss()
250
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
251
- if not return_dict:
252
- output = (lm_logits,) + transformer_outputs[1:]
253
- return (loss,) + output if loss is not None else output
254
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
255
-
256
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs: Any) -> dict:
257
- del kwargs
258
- if past:
259
- input_ids = input_ids[:, -1].unsqueeze(-1)
260
- bidirectional_mask = None
261
- if past[0][0].shape[0] == input_ids.shape[0]:
262
- past = self._convert_to_bloom_cache(past)
263
- else:
264
- bidirectional_mask = torch.ones_like(input_ids)
265
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
266
- setattr(model, 'forward', MethodType(forward, model))
267
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
268
- setattr(model, '_prefix_lm_converted', True)
269
- return model
270
-
271
- def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
272
- """Converts an OPT Causal LM to a Prefix LM.
273
-
274
- Supported HuggingFace model classes:
275
- - `OPTForCausalLM`
276
-
277
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
278
- """
279
- if hasattr(model, '_prefix_lm_converted'):
280
- return model
281
- assert isinstance(model, OPTForCausalLM)
282
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
283
- setattr(model, '_original_forward', getattr(model, 'forward'))
284
- setattr(model, '_original_generate', getattr(model, 'generate'))
285
- model.model.decoder.bidirectional_mask = None
286
-
287
- def _prepare_decoder_attention_mask(self: torch.nn.Module, attention_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], inputs_embeds: Optional[torch.Tensor], past_key_values_length: int):
288
- combined_attention_mask = None
289
- if input_shape[-1] > 1:
290
- assert inputs_embeds is not None
291
- if self.bidirectional_mask == 'g':
292
- (bsz, src_length) = input_shape
293
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
294
- else:
295
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
296
- if self.bidirectional_mask is not None:
297
- assert attention_mask is not None
298
- assert attention_mask.shape == self.bidirectional_mask.shape
299
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
300
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
301
- if attention_mask is not None:
302
- assert inputs_embeds is not None
303
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
304
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
305
- return combined_attention_mask
306
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
307
-
308
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
309
-
310
- def call_og_forward():
311
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
312
- if bidirectional_mask is None:
313
- return call_og_forward()
314
- self.model.decoder.bidirectional_mask = bidirectional_mask
315
- try:
316
- outputs = call_og_forward()
317
- except:
318
- self.model.decoder.bidirectional_mask = None
319
- raise
320
- self.model.decoder.bidirectional_mask = None
321
- return outputs
322
-
323
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Any):
324
- """Wraps original generate to enable PrefixLM-style attention."""
325
- self.model.decoder.bidirectional_mask = 'g'
326
- try:
327
- output = self._original_generate(*args, **kwargs)
328
- except:
329
- self.model.decoder.bidirectional_mask = None
330
- raise
331
- self.model.decoder.bidirectional_mask = None
332
- return output
333
- setattr(model, 'forward', MethodType(forward, model))
334
- setattr(model, 'generate', MethodType(generate, model))
335
- setattr(model, '_prefix_lm_converted', True)
336
- return model
337
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
338
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
339
 
340
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
341
  """Converts a HuggingFace Causal LM to a Prefix LM.
@@ -345,8 +111,6 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
345
  - `GPTNeoForCausalLM`
346
  - `GPTNeoXForCausalLM`
347
  - `GPTJForCausalLM`
348
- - `BloomForCausalLM`
349
- - `OPTForCausalLM`
350
 
351
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
352
  `generate` method and/or select underlying methods depending on the model class.
@@ -396,10 +160,6 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
396
  """
397
  if isinstance(model, _SUPPORTED_GPT_MODELS):
398
  return _convert_gpt_causal_lm_to_prefix_lm(model)
399
- elif isinstance(model, BloomForCausalLM):
400
- return _convert_bloom_causal_lm_to_prefix_lm(model)
401
- elif isinstance(model, OPTForCausalLM):
402
- return _convert_opt_causal_lm_to_prefix_lm(model)
403
  else:
404
  raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
405
 
 
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
 
 
9
  from types import MethodType
10
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
11
  import torch
 
 
 
 
12
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
13
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
14
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
15
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 
 
 
 
16
  _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
17
  CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
18
 
 
100
  setattr(model, 'generate', MethodType(generate, model))
101
  setattr(model, '_prefix_lm_converted', True)
102
  return model
103
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
104
+ CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
107
  """Converts a HuggingFace Causal LM to a Prefix LM.
 
111
  - `GPTNeoForCausalLM`
112
  - `GPTNeoXForCausalLM`
113
  - `GPTJForCausalLM`
 
 
114
 
115
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
116
  `generate` method and/or select underlying methods depending on the model class.
 
160
  """
161
  if isinstance(model, _SUPPORTED_GPT_MODELS):
162
  return _convert_gpt_causal_lm_to_prefix_lm(model)
 
 
 
 
163
  else:
164
  raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
165