reymondzzzz committed
Commit: 1799ef3
Parent: c4f295d

Upload GPTRefactForCausalLM

config.json CHANGED
@@ -5,6 +5,10 @@
   ],
   "attention_softmax_in_fp32": true,
   "attn_pdrop": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_gpt_refact.GPTRefactConfig",
+    "AutoModelForCausalLM": "modeling_gpt_refact.GPTRefactForCausalLM"
+  },
   "bos_token_id": 50256,
   "do_sample": true,
   "embd_pdrop": 0.1,
configuration_gpt_refact.py ADDED
@@ -0,0 +1,62 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class GPTRefactConfig(PretrainedConfig):
    model_type = "gpt_refact"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        attention_softmax_in_fp32=True,
        scale_attention_softmax_in_fp32=True,
        multi_query=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
        self.multi_query = multi_query

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
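
For reference, `attribute_map` makes the GPT-2-style hyperparameter names double as the standard Hugging Face ones, so downstream code can read either. A small sketch, assuming `configuration_gpt_refact.py` is on the Python path:

from configuration_gpt_refact import GPTRefactConfig

cfg = GPTRefactConfig()  # defaults taken from the __init__ signature above
assert cfg.n_embd == 768
# attribute_map aliases: the standard names resolve to the GPT-2-style fields.
assert cfg.hidden_size == cfg.n_embd
assert cfg.num_hidden_layers == cfg.n_layer
assert cfg.max_position_embeddings == cfg.n_positions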
modeling_gpt_refact.py ADDED
@@ -0,0 +1,651 @@
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
# Relative import so the Hub's dynamic module loader can resolve the config file.
from .configuration_gpt_refact import GPTRefactConfig


logger = logging.get_logger(__name__)


# Fused kernels
# Use separate functions for each case because conditionals prevent kernel fusion.
# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
#  Is it doable without writing 32 functions?
@torch.jit.script
def upcast_masked_softmax(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
    input_dtype = x.dtype
    x = x.to(softmax_dtype) * scale
    x = torch.where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
    return x


@torch.jit.script
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
    input_dtype = x.dtype
    x = x.to(softmax_dtype) * scale
    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
    return x


@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
    x = torch.where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1)
    return x


def _get_slopes(attn_heads: int, dev: str) -> torch.Tensor:
    r"""
    ## Get head-specific slope $m$ for each head
    * `attn_heads` is the number of heads in the attention layer $n$
    The slope for the first head is
    $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$
    The slopes for the rest of the heads are in a geometric series with the same ratio as above.
    For instance, when the number of heads is $8$ the slopes are
    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
    """

    # Get the closest power of 2 to `attn_heads`.
    # If `attn_heads` is not a power of 2, then we first calculate slopes for the closest (smaller) power of 2,
    # and then add the remaining slopes.
    n = 2 ** math.floor(math.log2(attn_heads))
    # $2^{-\frac{8}{n}}$
    m_0 = 2.0 ** (-8.0 / n)
    # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
    m = torch.pow(m_0, torch.arange(1, 1 + n, device=dev))

    # If `attn_heads` is not a power of 2, then we add the remaining slopes.
    # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously).
    # And pick the slopes up to `attn_heads`.
    if n < attn_heads:
        # $2^{-\frac{8}{2n}}$
        m_hat_0 = 2.0 ** (-4.0 / n)
        # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
        # Note that we take steps by $2$ to avoid slopes added previously.
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
        # Concatenate the slopes with the remaining slopes.
        m = torch.cat([m, m_hat])

    return m


def get_alibi_biases(
    B: int,
    T: int,
    attn_heads: int,
    dev: str,
    dtype,
    causal: bool = True) -> torch.Tensor:
    """
    ## Calculate the attention biases matrix
    * `B` is the batch size and `T` is the key/query sequence length
    * `attn_heads` is the number of heads in the attention layer
    This returns a tensor of shape `[B, attn_heads, T, T]` with ALiBi attention biases.
    """

    # Get slopes $m$ for each head
    if causal:
        mask = (torch.triu(torch.ones((T, T), device=dev)) == 1).transpose(0, 1)
    else:
        mask = torch.ones((T, T), device=dev, dtype=torch.bool)

    m = _get_slopes(attn_heads, dev)

    # Calculate distances $[0, 1, \dots, N]$
    # Here we calculate the distances using the mask.
    #
    # Since it's a causal mask we could also just use $[0, 1, \dots, N]$:
    # `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
    distance = mask.cumsum(dim=-1)

    # Multiply them pair-wise to get the ALiBi bias matrix
    biases = distance[:, :, None] * m[None, None, :]
    biases = biases.permute(2, 0, 1)[None, :, :T, :T]
    biases = biases.repeat(B, 1, 1, 1)
    return biases.to(dtype).contiguous()


class Attention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.mask_value = None

        self.multi_query = config.multi_query
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.kv_attn_heads = 1

        self.scale = self.head_dim ** -0.5

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.layer_idx = layer_idx
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = (
            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
        )

        self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.k = nn.Linear(self.embed_dim, self.head_dim, bias=False)
        self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def _get_mask_value(self, device, dtype):
        # Cache a dtype/device-specific mask value for the fused masked softmax
        # (mirrors the GPTBigCode helper of the same name, which `_attn` expects).
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def _attn(self, query, key, value, attention_mask=None, alibi=None):
        dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
        upcast = dtype != softmax_dtype
        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1

        # MQA models: (batch_size, query_length, num_heads * head_dim)
        # MHA models: (batch_size, num_heads, query_length, head_dim)
        attn_weights = alibi + torch.matmul(query * self.scale, key)

        if upcast:
            # Use a fused kernel to prevent a large overhead from casting and scaling.
            # Sub-optimal when the key length is not a multiple of 8.
            if attention_mask is None:
                attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
            else:
                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
                attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
        else:
            if attention_mask is not None:
                # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
                attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor):
        new_shape = tensor.shape[:-1] + (self.num_heads, self.head_dim)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        b, t, _ = hidden_states.shape
        query = self.q(hidden_states)
        key = self.k(hidden_states)
        value = self.v(hidden_states)
        query = self._split_heads(query)
        key = key.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)

        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
        attn_output = self.c_proj(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, intermediate_size, config, multiple_of: int = 256):
        super().__init__()
        embed_dim = config.hidden_size
        hidden_dim = intermediate_size
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.linear_1 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.linear_3 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
        x1 = F.silu(self.linear_1(x))
        x2 = self.linear_3(x)
        x = self.c_proj(x1 * x2)
        return x


class LayerNormNoBias(nn.Module):

    def __init__(self, shape: int, eps: float = 1e-5):
        super().__init__()
        self.shape = (shape,)
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(self.shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.shape, self.weight, None, self.eps)


class GPTRefactBlock(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = Attention(config, layer_idx=layer_idx)
        self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = MLP(self.inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.Tensor]],
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        hidden_states_norm = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states_norm,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        mix = attn_output + hidden_states

        norm_mix = self.ln_2(mix)
        feed_forward_hidden_states = self.mlp(norm_mix)
        # residual connection
        hidden_states = mix + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)


class GPTRefactPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTRefactConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTRefactBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (MLP, Attention)):
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            module.c_proj.weight.data.normal_(
                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
            )
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNormNoBias):
            # LayerNormNoBias is not an nn.LayerNorm subclass, so initialize its weight explicitly.
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GPTRefactModel):
            module.gradient_checkpointing = value


class GPTRefactModel(GPTRefactPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

        self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
        )

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _make_mask(seq_len: int, past_key_values_length: int):
        # prompt
        if past_key_values_length == 0:
            mask = torch.ones((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
            mask = torch.triu(mask, 1)
        else:
            mask = torch.zeros((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
        return mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_length > 0:
                position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
        elif position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Self-attention mask.
        query_length = input_shape[-1]

        seq_length_with_past = past_length + query_length
        if attention_mask is None:
            attention_mask = self._make_mask(query_length, past_length).to(device)
        else:
            attention_mask = attention_mask.to(device)

        hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds

        alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
                                 self.num_heads, device, self.wte.weight.dtype)[:, :, -query_length:, :]

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    alibi=alibi,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPTRefactModel(config)
        self.ln_f = nn.LayerNorm(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            if past_key_values is not None:
                model_inputs = {"input_ids": input_ids[..., -1:]}
            else:
                model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        x = self.ln_f(hidden_states)
        lm_logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
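
Putting the two uploaded files together, a minimal end-to-end sketch of loading and generating through the Auto classes; the repository id and prompt are placeholders, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<namespace>/<repo>", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("<namespace>/<repo>", trust_remote_code=True).eval()

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
with torch.no_grad():
    # generate() is inherited from PreTrainedModel; prepare_inputs_for_generation above
    # trims input_ids to the last token once past_key_values is populated.
    output_ids = model.generate(inputs["input_ids"], max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0]))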