kmfoda commited on
Commit
c5237a9
·
1 Parent(s): 3614757

Upload config and modelling files

Browse files
config.json CHANGED
@@ -1,271 +1,25 @@
1
  {
2
- "_name_or_path": "distributed/optimized-gpt2-2b",
3
  "activation_function": "gelu_new",
4
  "all_reduce_scores": {
5
  "0": "SUCCESS",
6
  "1": "SUCCESS",
7
- "10": "SUCCESS",
8
- "100": "SUCCESS",
9
- "101": "SUCCESS",
10
- "102": "SUCCESS",
11
- "103": "SUCCESS",
12
- "104": "SUCCESS",
13
- "105": "SUCCESS",
14
- "106": "SUCCESS",
15
- "107": "SUCCESS",
16
- "108": "SUCCESS",
17
- "109": "SUCCESS",
18
- "11": "SUCCESS",
19
- "110": "SUCCESS",
20
- "111": "SUCCESS",
21
- "112": "SUCCESS",
22
- "113": "SUCCESS",
23
- "114": "SUCCESS",
24
- "115": "SUCCESS",
25
- "116": "SUCCESS",
26
- "117": "SUCCESS",
27
- "118": "SUCCESS",
28
- "119": "SUCCESS",
29
- "12": "SUCCESS",
30
- "120": "SUCCESS",
31
- "121": "SUCCESS",
32
- "122": "SUCCESS",
33
- "123": "SUCCESS",
34
- "124": "SUCCESS",
35
- "125": "SUCCESS",
36
- "126": "SUCCESS",
37
- "127": "SUCCESS",
38
- "128": "SUCCESS",
39
- "129": "SUCCESS",
40
- "13": "SUCCESS",
41
- "130": "SUCCESS",
42
- "131": "SUCCESS",
43
- "132": "SUCCESS",
44
- "133": "SUCCESS",
45
- "134": "SUCCESS",
46
- "135": "SUCCESS",
47
- "136": "SUCCESS",
48
- "137": "SUCCESS",
49
- "138": "SUCCESS",
50
- "139": "SUCCESS",
51
- "14": "SUCCESS",
52
- "140": "SUCCESS",
53
- "141": "SUCCESS",
54
- "142": "SUCCESS",
55
- "143": "SUCCESS",
56
- "144": "SUCCESS",
57
- "145": "SUCCESS",
58
- "146": "SUCCESS",
59
- "147": "SUCCESS",
60
- "148": "SUCCESS",
61
- "149": "SUCCESS",
62
- "15": "SUCCESS",
63
- "150": "SUCCESS",
64
- "151": "SUCCESS",
65
- "152": "SUCCESS",
66
- "153": "SUCCESS",
67
- "154": "SUCCESS",
68
- "155": "SUCCESS",
69
- "156": "SUCCESS",
70
- "157": "SUCCESS",
71
- "158": "SUCCESS",
72
- "159": "SUCCESS",
73
- "16": "SUCCESS",
74
- "160": "SUCCESS",
75
- "161": "SUCCESS",
76
- "162": "SUCCESS",
77
- "163": "SUCCESS",
78
- "164": "SUCCESS",
79
- "165": "SUCCESS",
80
- "166": "SUCCESS",
81
- "167": "SUCCESS",
82
- "168": "SUCCESS",
83
- "169": "SUCCESS",
84
- "17": "SUCCESS",
85
- "170": "SUCCESS",
86
- "171": "SUCCESS",
87
- "172": "SUCCESS",
88
- "173": "SUCCESS",
89
- "174": "SUCCESS",
90
- "175": "SUCCESS",
91
- "176": "SUCCESS",
92
- "177": "SUCCESS",
93
- "178": "SUCCESS",
94
- "179": "SUCCESS",
95
- "18": "SUCCESS",
96
- "180": "SUCCESS",
97
- "181": "SUCCESS",
98
- "182": "SUCCESS",
99
- "183": "SUCCESS",
100
- "184": "SUCCESS",
101
- "185": "SUCCESS",
102
- "186": "SUCCESS",
103
- "187": "SUCCESS",
104
- "188": "SUCCESS",
105
- "189": "SUCCESS",
106
- "19": "SUCCESS",
107
- "190": "SUCCESS",
108
- "191": "SUCCESS",
109
- "192": "SUCCESS",
110
- "193": "SUCCESS",
111
- "194": "SUCCESS",
112
- "195": "SUCCESS",
113
- "196": "SUCCESS",
114
- "197": "SUCCESS",
115
- "198": "SUCCESS",
116
- "199": "SUCCESS",
117
  "2": "SUCCESS",
118
- "20": "SUCCESS",
119
- "200": "SUCCESS",
120
- "201": "SUCCESS",
121
- "202": "SUCCESS",
122
- "203": "SUCCESS",
123
- "204": "SUCCESS",
124
- "205": "SUCCESS",
125
- "206": "SUCCESS",
126
- "207": "SUCCESS",
127
- "208": "SUCCESS",
128
- "209": "SUCCESS",
129
- "21": "SUCCESS",
130
- "210": "SUCCESS",
131
- "211": "SUCCESS",
132
- "212": "SUCCESS",
133
- "213": "SUCCESS",
134
- "214": "SUCCESS",
135
- "215": "SUCCESS",
136
- "216": "SUCCESS",
137
- "217": "SUCCESS",
138
- "218": "SUCCESS",
139
- "219": "SUCCESS",
140
- "22": "SUCCESS",
141
- "220": "SUCCESS",
142
- "221": "SUCCESS",
143
- "222": "SUCCESS",
144
- "223": "SUCCESS",
145
- "224": "SUCCESS",
146
- "225": "SUCCESS",
147
- "226": "SUCCESS",
148
- "227": "SUCCESS",
149
- "228": "SUCCESS",
150
- "229": "SUCCESS",
151
- "23": "SUCCESS",
152
- "230": "SUCCESS",
153
- "231": "SUCCESS",
154
- "232": "SUCCESS",
155
- "233": "SUCCESS",
156
- "234": "SUCCESS",
157
- "235": "SUCCESS",
158
- "236": "SUCCESS",
159
- "237": "SUCCESS",
160
- "238": "SUCCESS",
161
- "239": "SUCCESS",
162
- "24": "SUCCESS",
163
- "240": "SUCCESS",
164
- "241": "SUCCESS",
165
- "242": "SUCCESS",
166
- "243": "SUCCESS",
167
- "244": "SUCCESS",
168
- "245": "SUCCESS",
169
- "246": "SUCCESS",
170
- "247": "SUCCESS",
171
- "248": "SUCCESS",
172
- "249": "SUCCESS",
173
- "25": "SUCCESS",
174
- "250": "SUCCESS",
175
- "251": "SUCCESS",
176
- "252": "SUCCESS",
177
- "253": "SUCCESS",
178
- "254": "SUCCESS",
179
- "255": "SUCCESS",
180
- "26": "SUCCESS",
181
- "27": "SUCCESS",
182
- "28": "SUCCESS",
183
- "29": "SUCCESS",
184
- "3": "SUCCESS",
185
- "30": "SUCCESS",
186
- "31": "SUCCESS",
187
- "32": "SUCCESS",
188
- "33": "SUCCESS",
189
- "34": "SUCCESS",
190
- "35": "SUCCESS",
191
- "36": "SUCCESS",
192
- "37": "SUCCESS",
193
- "38": "SUCCESS",
194
- "39": "SUCCESS",
195
- "4": "SUCCESS",
196
- "40": "SUCCESS",
197
- "41": "SUCCESS",
198
- "42": "SUCCESS",
199
- "43": "SUCCESS",
200
- "44": "SUCCESS",
201
- "45": "SUCCESS",
202
- "46": "SUCCESS",
203
- "47": "SUCCESS",
204
- "48": "SUCCESS",
205
- "49": "SUCCESS",
206
- "5": "SUCCESS",
207
- "50": "SUCCESS",
208
- "51": "SUCCESS",
209
- "52": "SUCCESS",
210
- "53": "SUCCESS",
211
- "54": "SUCCESS",
212
- "55": "SUCCESS",
213
- "56": "SUCCESS",
214
- "57": "SUCCESS",
215
- "58": "SUCCESS",
216
- "59": "SUCCESS",
217
- "6": "SUCCESS",
218
- "60": "SUCCESS",
219
- "61": "SUCCESS",
220
- "62": "SUCCESS",
221
- "63": "SUCCESS",
222
- "64": "SUCCESS",
223
- "65": "SUCCESS",
224
- "66": "SUCCESS",
225
- "67": "SUCCESS",
226
- "68": "SUCCESS",
227
- "69": "SUCCESS",
228
- "7": "SUCCESS",
229
- "70": "SUCCESS",
230
- "71": "SUCCESS",
231
- "72": "SUCCESS",
232
- "73": "SUCCESS",
233
- "74": "SUCCESS",
234
- "75": "SUCCESS",
235
- "76": "SUCCESS",
236
- "77": "SUCCESS",
237
- "78": "SUCCESS",
238
- "79": "SUCCESS",
239
- "8": "SUCCESS",
240
- "80": "SUCCESS",
241
- "81": "SUCCESS",
242
- "82": "SUCCESS",
243
- "83": "SUCCESS",
244
- "84": "SUCCESS",
245
- "85": "SUCCESS",
246
- "86": "SUCCESS",
247
- "87": "SUCCESS",
248
- "88": "SUCCESS",
249
- "89": "SUCCESS",
250
- "9": "SUCCESS",
251
- "90": "SUCCESS",
252
- "91": "SUCCESS",
253
- "92": "SUCCESS",
254
- "93": "SUCCESS",
255
- "94": "SUCCESS",
256
- "95": "SUCCESS",
257
- "96": "SUCCESS",
258
- "97": "SUCCESS",
259
- "98": "SUCCESS",
260
- "99": "SUCCESS"
261
  },
262
  "architectures": [
263
  "GPTOptim"
264
  ],
265
  "attn_pdrop": 0.1,
266
  "auto_map": {
267
- "AutoConfig": "distributed/optimized-gpt2-500m--configuration_gpt_optimized.GPTOptimConfig",
268
- "AutoModelForCausalLM": "distributed/optimized-gpt2-500m--modeling_gpt_optimized.GPTOptim"
269
  },
270
  "block_size": 1024,
271
  "bos_token_id": 50256,
 
1
  {
2
+ "_name_or_path": "distributed/optimized-gpt2-2b-vtestnet-v1",
3
  "activation_function": "gelu_new",
4
  "all_reduce_scores": {
5
  "0": "SUCCESS",
6
  "1": "SUCCESS",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  "2": "SUCCESS",
8
+ "3": "NON_PARTICIPATING",
9
+ "4": "NON_PARTICIPATING",
10
+ "5": "NON_PARTICIPATING",
11
+ "6": "NON_PARTICIPATING",
12
+ "7": "NON_PARTICIPATING",
13
+ "8": "NON_PARTICIPATING",
14
+ "9": "NON_PARTICIPATING"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  },
16
  "architectures": [
17
  "GPTOptim"
18
  ],
19
  "attn_pdrop": 0.1,
20
  "auto_map": {
21
+ "AutoConfig": "configuration_gpt_optimized.GPTOptimConfig",
22
+ "AutoModelForCausalLM": "modeling_gpt_optimized.GPTOptim"
23
  },
24
  "block_size": 1024,
25
  "bos_token_id": 50256,
configuration_gpt_optimized.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, GPT2Config
2
+ from typing import List
3
+
4
+
5
+ class GPTOptimConfig(GPT2Config):
6
+ model_type = "gpt_optimized"
7
+
8
+ def __init__(
9
+ self,
10
+ block_size: int = 1024, # max sequence length
11
+ vocab_size: int = 50257, # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
12
+ n_layer: int = 16, # number of layers
13
+ n_head: int = 16, # number of heads
14
+ n_embd: int = 1024, # embedding dimension
15
+ **kwargs,
16
+ ):
17
+ super().__init__(**kwargs)
18
+ self.block_size = block_size
19
+ self.vocab_size = vocab_size
20
+ self.n_layer = n_layer
21
+ self.n_head = n_head
22
+ self.n_embd = n_embd
modeling_gpt_optimized.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import bitsandbytes
4
+ from torch.nn import CrossEntropyLoss, functional as F
5
+ from transformers import PreTrainedModel, GPT2PreTrainedModel
6
+ from .configuration_gpt_optimized import GPTOptimConfig
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
8
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
9
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
10
+ from typing import Optional, Tuple, Union
11
+
12
+ _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
13
+ _CONFIG_FOR_DOC = "GPT2Config"
14
+
15
+ GPT2_INPUTS_DOCSTRING = r"""
16
+ Args:
17
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
18
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
19
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
20
+ sequence tokens in the vocabulary.
21
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
22
+ `input_ids`.
23
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
24
+ [`PreTrainedTokenizer.__call__`] for details.
25
+ [What are input IDs?](../glossary#input-ids)
26
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
27
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
28
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
29
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
30
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
31
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
32
+ - 1 for tokens that are **not masked**,
33
+ - 0 for tokens that are **masked**.
34
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
35
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
36
+ `len(past_key_values) + len(input_ids)`
37
+ [What are attention masks?](../glossary#attention-mask)
38
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
39
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
40
+ 1]`:
41
+ - 0 corresponds to a *sentence A* token,
42
+ - 1 corresponds to a *sentence B* token.
43
+ [What are token type IDs?](../glossary#token-type-ids)
44
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
45
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
46
+ config.max_position_embeddings - 1]`.
47
+ [What are position IDs?](../glossary#position-ids)
48
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
49
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
50
+ - 1 indicates the head is **not masked**,
51
+ - 0 indicates the head is **masked**.
52
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
53
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
54
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
55
+ model's internal embedding lookup matrix.
56
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
57
+ `past_key_values`).
58
+ use_cache (`bool`, *optional*):
59
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
60
+ `past_key_values`).
61
+ output_attentions (`bool`, *optional*):
62
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
63
+ tensors for more detail.
64
+ output_hidden_states (`bool`, *optional*):
65
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
66
+ more detail.
67
+ return_dict (`bool`, *optional*):
68
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
69
+ """
70
+
71
+ class CausalSelfAttention(nn.Module):
72
+
73
+ def __init__(self, config):
74
+ super().__init__()
75
+ assert config.n_embd % config.n_head == 0
76
+ # key, query, value projections for all heads, but in a batch
77
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
78
+ # output projection
79
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
80
+ self.c_proj.NANOGPT_SCALE_INIT = 1
81
+ # regularization
82
+ self.n_head = config.n_head
83
+ self.n_embd = config.n_embd
84
+
85
+ def forward(self, x):
86
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
87
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
88
+ # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
89
+ # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
90
+ qkv = self.c_attn(x)
91
+ q, k, v = qkv.split(self.n_embd, dim=2)
92
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
93
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
94
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
95
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
96
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
97
+ # output projection
98
+ y = self.c_proj(y)
99
+ return y
100
+
101
+ class MLP(nn.Module):
102
+
103
+ def __init__(self, config):
104
+ super().__init__()
105
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
106
+ self.gelu = nn.GELU(approximate='tanh')
107
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
108
+ self.c_proj.NANOGPT_SCALE_INIT = 1
109
+
110
+ def forward(self, x):
111
+ x = self.c_fc(x)
112
+ x = self.gelu(x)
113
+ x = self.c_proj(x)
114
+ return x
115
+
116
+ class Block(nn.Module):
117
+
118
+ def __init__(self, config):
119
+ super().__init__()
120
+ self.ln_1 = nn.LayerNorm(config.n_embd)
121
+ self.attn = CausalSelfAttention(config)
122
+ self.ln_2 = nn.LayerNorm(config.n_embd)
123
+ self.mlp = MLP(config)
124
+
125
+ def forward(self, x):
126
+ x = x + self.attn(self.ln_1(x))
127
+ x = x + self.mlp(self.ln_2(x))
128
+ return x
129
+
130
+ class GPT(nn.Module):
131
+
132
+ def __init__(self, config):
133
+ super().__init__()
134
+ self.config = config
135
+
136
+ self.transformer = nn.ModuleDict(dict(
137
+ wte = bitsandbytes.nn.StableEmbedding(config.vocab_size, config.n_embd),
138
+ wpe = bitsandbytes.nn.StableEmbedding(config.block_size, config.n_embd),
139
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
140
+ ln_f = nn.LayerNorm(config.n_embd),
141
+ ))
142
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143
+
144
+ # weight sharing scheme
145
+ self.transformer.wte.weight = self.lm_head.weight
146
+
147
+ # init params
148
+ self.apply(self._init_weights)
149
+
150
+ def _init_weights(self, module):
151
+ if isinstance(module, nn.Linear):
152
+ std = 0.02
153
+ if hasattr(module, 'NANOGPT_SCALE_INIT'):
154
+ std *= (2 * self.config.n_layer) ** -0.5
155
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
156
+ if module.bias is not None:
157
+ torch.nn.init.zeros_(module.bias)
158
+ elif isinstance(module, nn.Embedding):
159
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
160
+
161
+ class GPTOptim(GPT2PreTrainedModel):
162
+ config_class = GPTOptimConfig
163
+
164
+ def __init__(self, config):
165
+ super().__init__(config)
166
+ self.model = GPT(
167
+ config
168
+ )
169
+ self.config = config
170
+
171
+ def forward(self, input_ids, labels=None):
172
+ # input_ids is of shape (B, T)
173
+ B, T = input_ids.size()
174
+ assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
175
+ # forward the token and posisition embeddings
176
+ pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device) # shape (T)
177
+ pos_emb = self.model.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
178
+ tok_emb = self.model.transformer.wte(input_ids) # token embeddings of shape (B, T, n_embd)
179
+ x = tok_emb + pos_emb
180
+ # forward the blocks of the transformer
181
+ for block in self.model.transformer.h:
182
+ x = block(x)
183
+ # forward the final layernorm and the classifier
184
+ x = self.model.transformer.ln_f(x)
185
+ logits = self.model.lm_head(x) # (B, T, vocab_size)
186
+ loss = None
187
+ if labels is not None:
188
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=self.config.eos_token_id)
189
+ return logits, loss