reymondzzzz committed
Commit 680dd84
1 Parent(s): 1c551a8

Upload GPTRefactForCausalLM

config.json CHANGED
@@ -1,29 +1,27 @@
 {
-  "activation_function": "gelu_new",
   "architectures": [
     "GPTRefactForCausalLM"
   ],
-  "attention_softmax_in_fp32": true,
+  "attention_softmax_in_fp32": false,
   "attn_pdrop": 0.1,
   "auto_map": {
     "AutoConfig": "configuration_gpt_refact.GPTRefactConfig",
     "AutoModelForCausalLM": "modeling_gpt_refact.GPTRefactForCausalLM"
   },
-  "bos_token_id": 50256,
+  "bos_token_id": 0,
   "do_sample": true,
   "embd_pdrop": 0.1,
-  "eos_token_id": 50256,
+  "eos_token_id": 0,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
   "model_type": "gpt_refact",
-  "multi_query": true,
   "n_embd": 2048,
   "n_head": 32,
   "n_inner": null,
   "n_layer": 32,
   "n_positions": 1024,
   "resid_pdrop": 0.1,
-  "scale_attention_softmax_in_fp32": true,
+  "scale_attention_softmax_in_fp32": false,
   "scale_attn_weights": true,
   "torch_dtype": "float32",
   "transformers_version": "4.28.1",
configuration_gpt_refact.py CHANGED
@@ -17,13 +17,12 @@ class GPTRefactConfig(PretrainedConfig):
 
     def __init__(
         self,
-        vocab_size=50257,
+        vocab_size=49216,
         n_positions=1024,
         n_embd=768,
         n_layer=12,
         n_head=12,
         n_inner=None,
-        activation_function="gelu_new",
         resid_pdrop=0.1,
         embd_pdrop=0.1,
         attn_pdrop=0.1,
@@ -31,11 +30,10 @@ class GPTRefactConfig(PretrainedConfig):
         initializer_range=0.02,
         scale_attn_weights=True,
         use_cache=True,
-        bos_token_id=50256,
-        eos_token_id=50256,
-        attention_softmax_in_fp32=True,
-        scale_attention_softmax_in_fp32=True,
-        multi_query=True,
+        bos_token_id=0,
+        eos_token_id=0,
+        attention_softmax_in_fp32=False,
+        scale_attention_softmax_in_fp32=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -44,7 +42,6 @@ class GPTRefactConfig(PretrainedConfig):
         self.n_layer = n_layer
         self.n_head = n_head
         self.n_inner = n_inner
-        self.activation_function = activation_function
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attn_pdrop = attn_pdrop
@@ -54,7 +51,6 @@ class GPTRefactConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
-        self.multi_query = multi_query
 
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
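
Taken together, the new defaults shrink the vocabulary to 49216 and move the special tokens to id 0. A quick sanity-check sketch, assuming the updated configuration_gpt_refact.py is importable locally (the import path is illustrative, not something defined by the commit):

from configuration_gpt_refact import GPTRefactConfig

cfg = GPTRefactConfig()
print(cfg.vocab_size)                      # 49216 (previously 50257)
print(cfg.bos_token_id, cfg.eos_token_id)  # 0 0 (previously 50256)
print(cfg.attention_softmax_in_fp32)       # False (previously True)
print(hasattr(cfg, "multi_query"))         # False: the flag was dropped from __init__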
generation_config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "_from_model_config": true,
-  "bos_token_id": 50256,
+  "bos_token_id": 0,
   "do_sample": true,
-  "eos_token_id": 50256,
+  "eos_token_id": 0,
   "transformers_version": "4.28.1"
 }
modeling_gpt_refact.py CHANGED
@@ -1,39 +1,27 @@
 import math
-from typing import List, Optional, Tuple, Union
-
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-import torch.nn.functional as F
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
-from transformers.activations import ACT2FN
+from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
-    SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
     logging,
 )
-from hf.configuration_gpt_refact import GPTRefactConfig
+from typing import List, Optional, Tuple, Union
 
+from hf.configuration_gpt_refact import GPTRefactConfig
 
 logger = logging.get_logger(__name__)
 
 
-# Fused kernels
-# Use separate functions for each case because conditionals prevent kernel fusion.
-# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
-# Is it doable without writing 32 functions?
 @torch.jit.script
 def upcast_masked_softmax(
-    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
 ):
     input_dtype = x.dtype
     x = x.to(softmax_dtype) * scale
@@ -56,7 +44,8 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
     x = torch.nn.functional.softmax(x, dim=-1)
     return x
 
-def _get_slopes(attn_heads: int, dev: str) -> torch.Tensor:
+@torch.jit.script
+def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
     """
     ## Get head-specific slope $m$ for each head
     * `n_heads` is the number of heads in the attention layer $n$
@@ -70,7 +59,7 @@ def _get_slopes(attn_heads: int, dev: str) -> torch.Tensor:
     # Get the closest power of 2 to `n_heads`.
     # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2,
     # and then add the remaining slopes.
-    n = 2 ** math.floor(math.log2(attn_heads))
+    n = 2 ** math.floor(math.log(attn_heads, 2))
     # $2^{-\frac{8}{n}}$
     m_0 = 2.0 ** (-8.0 / n)
     # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
@@ -90,13 +79,13 @@ def _get_slopes(attn_heads: int, dev: str) -> torch.Tensor:
 
     return m
 
-
+@torch.jit.script
 def get_alibi_biases(
     B: int,
     T: int,
     attn_heads: int,
-    dev: str,
-    dtype,
+    dev: torch.device,
+    dtype: torch.dtype,
     causal: bool = True) -> torch.Tensor:
     """
     ## Calculate the attention biases matrix
@@ -126,12 +115,12 @@ def get_alibi_biases(
     biases = biases.repeat(B, 1, 1, 1)
     return biases.to(dtype).contiguous()
 
+
 class Attention(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         self.mask_value = None
 
-        self.multi_query = config.multi_query
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
@@ -148,7 +137,7 @@ class Attention(nn.Module):
         self.layer_idx = layer_idx
         self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
         self.scale_attention_softmax_in_fp32 = (
-            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
+            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
         )
 
         self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -162,13 +151,9 @@ class Attention(nn.Module):
         upcast = dtype != softmax_dtype
         unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
 
-        # MQA models: (batch_size, query_length, num_heads * head_dim)
-        # MHA models: (batch_size, num_heads, query_length, head_dim)
         attn_weights = alibi + torch.matmul(query * self.scale, key)
 
         if upcast:
-            # Use a fused kernel to prevent a large overhead from casting and scaling.
-            # Sub-optimal when the key length is not a multiple of 8.
             if attention_mask is None:
                 attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
             else:
@@ -176,8 +161,6 @@ class Attention(nn.Module):
                 attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
         else:
             if attention_mask is not None:
-
-                # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
                 attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)
 
             attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
@@ -192,15 +175,13 @@ class Attention(nn.Module):
         return tensor.permute(0, 2, 1, 3)
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        layer_past: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        alibi: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
+        self,
+        hidden_states: torch.Tensor,
+        layer_past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
     ) -> Union[
         Tuple[torch.Tensor, Optional[torch.Tensor]],
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
@@ -264,6 +245,7 @@ class LayerNormNoBias(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(x, self.shape, self.weight, None, self.eps)
 
+
 class GPTRefactBlock(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
@@ -277,15 +259,13 @@ class GPTRefactBlock(nn.Module):
         self.mlp = MLP(self.inner_dim, config)
 
     def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.Tensor]],
-        layer_past: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        alibi: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
+        self,
+        hidden_states: Optional[Tuple[torch.Tensor]],
+        layer_past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
     ) -> Union[
         Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
     ]:
@@ -317,11 +297,6 @@ class GPTRefactBlock(nn.Module):
 
 
 class GPTRefactPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
     config_class = GPTRefactConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
@@ -332,7 +307,6 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
         super().__init__(*inputs, **kwargs)
 
     def _init_weights(self, module):
-        """Initialize the weights."""
         if isinstance(module, (MLP, Attention)):
             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
             # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
@@ -354,8 +328,7 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
+        elif isinstance(module, LayerNormNoBias):
             module.weight.data.fill_(1.0)
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -394,20 +367,15 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
         return mask
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -433,27 +401,12 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
-
         if past_key_values is None:
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
 
-        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_length > 0:
-                position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
-        elif position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
-
         # Self-attention mask.
         query_length = input_shape[-1]
 
@@ -468,10 +421,6 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
         alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
                                  self.num_heads, device, self.wte.weight.dtype)[:, :, -query_length:, :]
 
-        if token_type_ids is not None:
-            token_type_embeds = self.wte(token_type_ids)
-            hidden_states = hidden_states + token_type_embeds
-
         output_shape = input_shape + (hidden_states.size(-1),)
 
         presents = [] if use_cache else None
@@ -496,9 +445,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
+                    alibi
                 )
             else:
                 outputs = block(
@@ -506,8 +453,6 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
                     layer_past=layer_past,
                     attention_mask=attention_mask,
                     alibi=alibi,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                 )
@@ -541,21 +486,20 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
             cross_attentions=all_cross_attentions,
         )
 
+
 class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
 
     def __init__(self, config):
        super().__init__(config)
        self.transformer = GPTRefactModel(config)
-        self.ln_f = nn.LayerNorm(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
+        self.ln_f = LayerNormNoBias(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
        # Initialize weights and apply final processing
        self.post_init()
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
@@ -573,21 +517,16 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
         return model_inputs
 
     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -601,12 +540,7 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -641,7 +575,7 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
 
     @staticmethod
     def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
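
Most of the modeling changes trim unused arguments (token_type_ids, position_ids, head_mask, encoder inputs), switch the final norm to a tied LayerNormNoBias inside GPTRefactForCausalLM, and make the ALiBi helpers (_get_slopes, get_alibi_biases) TorchScript-compiled with explicit torch.device/torch.dtype annotations. For reference, a standalone sketch of the slope schedule _get_slopes appears to compute; the non-power-of-two branch is not visible in the hunks above, so the 2 ** (-4 / n) series at odd powers is an assumption based on the standard ALiBi recipe.

import math

def alibi_slopes(attn_heads: int) -> list:
    # Largest power of two <= attn_heads.
    n = 2 ** math.floor(math.log2(attn_heads))
    # Geometric series starting at 2 ** (-8 / n), one slope per head.
    m_0 = 2.0 ** (-8.0 / n)
    slopes = [m_0 ** i for i in range(1, n + 1)]
    if n < attn_heads:
        # Assumed remaining-heads branch: a second series with ratio 2 ** (-4 / n), taken at odd powers.
        m_hat_0 = 2.0 ** (-4.0 / n)
        slopes += [m_hat_0 ** i for i in range(1, 1 + 2 * (attn_heads - n), 2)]
    return slopes

print(alibi_slopes(8))  # [0.5, 0.25, 0.125, ..., 0.00390625] for 8 heads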
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:90c3d5d52d19ceeec70bb4c53edfe2d98eb1acb710c62446e426391adb32301c
-size 6343470101
+oid sha256:58d077cf9a7cf9fa4589e3adb03603dab48a47af9e9a9bf084add65fe7574811
+size 6343461637