BeardedMonster committed
Commit 2e195a2 · verified · 1 Parent(s): 7eb7f81

Upload GPTJXForCausalLM

Files changed (1):
  1. pretrained_model.py +4 -61
pretrained_model.py CHANGED
@@ -173,9 +173,6 @@ class GPTJXForCausalLM(PreTrainedModel):
         device = idx.device
         b, t = idx.size()
 
-        # attn_mask = _prepare_mask_(idx, b, eval)
-        # print("attention mask: ", attn_mask)
-
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
 
@@ -186,17 +183,16 @@ class GPTJXForCausalLM(PreTrainedModel):
         for block in self.transformer.h:
             x = block(x, attn_mask=attn_mask)
         x = self.transformer.ln_f(x)
+
+        logits = self.lm_head(x) # logits over the entire sequence, shape (b, t, vocab_size)
 
         if targets is not None:
-            # if we are given some desired targets also calculate the loss
-            logits = self.lm_head(x)
+            # If targets are provided, compute the loss
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
         else:
-            # inference-time mini-optimization: only forward the lm_head on the very last position
-            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            # Inference-time: return logits for each timestep
             loss = None
 
-        # return {"logits": logits, "loss": loss}
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -213,38 +209,6 @@ class GPTJXForCausalLM(PreTrainedModel):
             model_inputs["attn_mask"] = attention_mask
 
         return model_inputs
-
-
-    # @torch.no_grad()
-    # def stream(self, idx, max_new_tokens, temperature=1.0, top_k=None, gen_mode="greedy"):
-    #     """
-    #     Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
-    #     the sequence max_new_tokens times, feeding the predictions back into the model each time.
-    #     Most likely you'll want to make sure to be in model.eval() mode of operation for this.
-    #     """
-    #     for _ in range(max_new_tokens):
-    #         # if the sequence context is growing too long we must crop it at block_size
-    #         idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
-    #         # forward the model to get the logits for the index in the sequence
-    #         logits, _ = self(idx_cond, eval=True)
-    #         # pluck the logits at the final step and scale by desired temperature
-    #         logits = logits[:, -1, :] / temperature
-    #         # optionally crop the logits to only the top k options
-    #         if top_k is not None:
-    #             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-    #             logits[logits < v[:, [-1]]] = -float('Inf')
-    #         # apply softmax to convert logits to (normalized) probabilities
-    #         probs = F.softmax(logits, dim=-1)
-    #         # sample from the distribution
-    #         if gen_mode == 'greedy':
-    #             idx_next = torch.argmax(probs, dim=-1).unsqueeze(0)
-
-    #         else:
-    #             idx_next = torch.multinomial(probs, num_samples=1)
-    #         # print(idx_next.shape, idx.shape)
-    #         idx = torch.cat((idx, idx_next), dim=1)
-    #         # append sampled index to the running sequence and continue
-    #         yield idx_next
 
 
     def crop_block_size(self, block_size):
@@ -263,24 +227,3 @@ AutoConfig.register("nanogpt-j", GPTJXConfig)
 AutoModel.register(GPTJXConfig,GPTJXForCausalLM)
 AutoModelForCausalLM.register(GPTJXConfig, GPTJXForCausalLM)
 
-
-# if __name__ == '__main__':
-#     from transformers import AutoTokenizer
-
-#     tokenizer = AutoTokenizer.from_pretrained("BeardedMonster/SabiYarn")
-#     input_ids = tokenizer("Awọn eeyan Cairo, ni Egypt ti bẹrẹ si n to lawọn ileesẹ to n ṣe burẹdi bayii.", return_tensors="pt")["input_ids"]
-#     targets = input_ids
-
-#     # config = GPTJConfig()
-#     # config.save_pretrained("gptj-config")
-#     # new_config = GPTJ.from_pretrained("gptj-config")
-#     # model = GPTJ(config)
-#     # state_dict = torch.load('model.pt', map_location="cpu")
-#     # model.load_state_dict(state_dict)
-#     model = GPTJXForCausalLM.from_pretrained("/pretrainedmodel")
-#     # model.save_pretrained("/pretrainedmodel")
-#     # outputs = model(input_ids, targets)
-#     # print(outputs)
-#     output = model.generate(input_ids, max_new_tokens=50)
-#     print(tokenizer.decode(output[0]))
-#     print(new_config)
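
A note on the revised forward path in the second hunk: logits are now computed over the entire sequence, and the loss flattens both logits and targets so that positions labelled -100 (e.g. padding) are skipped. The following standalone sketch reproduces just that loss computation with toy shapes; it is illustrative only and does not use the actual GPTJX modules.

import torch
import torch.nn.functional as F

# Toy shapes for illustration only; the real model uses its configured
# block_size and vocab_size.
b, t, vocab_size = 2, 8, 100

logits = torch.randn(b, t, vocab_size)          # stand-in for lm_head(x) over the full sequence
targets = torch.randint(0, vocab_size, (b, t))  # next-token labels
targets[:, -2:] = -100                          # e.g. padding positions, ignored by the loss

# Same flattening as in the diff: (b*t, vocab_size) logits vs. (b*t,) targets
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
print(loss.item())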
 
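
Since GPTJXConfig and GPTJXForCausalLM are registered with the Auto classes at the bottom of the file, the model can be exercised roughly the way the removed `__main__` block did. The sketch below assumes a script placed next to pretrained_model.py, a hypothetical local checkpoint directory ./pretrainedmodel, and the BeardedMonster/SabiYarn tokenizer referenced in that block; substitute your own paths.

import torch
from transformers import AutoTokenizer

from pretrained_model import GPTJXForCausalLM  # defines and registers the custom classes

# Hypothetical checkpoint/tokenizer locations -- adjust to your setup.
tokenizer = AutoTokenizer.from_pretrained("BeardedMonster/SabiYarn")
model = GPTJXForCausalLM.from_pretrained("./pretrainedmodel")
model.eval()

input_ids = tokenizer("Awọn eeyan Cairo, ni Egypt ti bẹrẹ si n to lawọn ileesẹ to n ṣe burẹdi bayii.",
                      return_tensors="pt")["input_ids"]

# Forward pass with targets: returns full-sequence logits and a cross-entropy loss
# (labels equal to -100 are ignored), packaged in a CausalLMOutputWithPast.
with torch.no_grad():
    outputs = model(input_ids, input_ids)
print(outputs.logits.shape, outputs.loss)

# Generation, as in the removed example.
output = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0]))

Passing targets positionally mirrors the commented-out `outputs = model(input_ids, targets)` call in the removed block; if the forward signature differs in your checkout, pass them by keyword instead.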