JingzeShi committed on
Commit b62417a · verified · 1 Parent(s): f31d51b

Upload DogeForCausalLM

Files changed (4)
  1. config.json +1 -1
  2. configuration_doge.py +14 -0
  3. generation_config.json +1 -1
  4. modeling_doge.py +255 -257
config.json CHANGED
@@ -41,7 +41,7 @@
   },
   "rope_theta": 10000.0,
   "torch_dtype": "float32",
- "transformers_version": "4.47.1",
+ "transformers_version": "4.48.1",
   "use_cache": true,
   "vocab_size": 32768
   }
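The only substantive change in config.json is the `transformers_version` stamp, i.e. the library version used to re-export the checkpoint (the refactored modeling code below relies on 4.48-era helpers such as `Unpack` and `LossKwargs`). A minimal sanity check before loading, shown purely as an illustration and not part of the repository:

```python
# Illustrative only: confirm the installed transformers is new enough for the
# APIs used by the updated remote code (4.48.1 is the version stamped above).
from packaging import version
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.48.1"), \
    f"found transformers {transformers.__version__}, expected >= 4.48.1"
```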
configuration_doge.py CHANGED
@@ -127,6 +127,17 @@ class DogeConfig(PretrainedConfig):
 
   model_type = "doge"
   keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `DogeModel`
+ base_model_tp_plan = {
+     "layers.*.self_attn.q_proj": "colwise",
+     "layers.*.self_attn.k_proj": "colwise",
+     "layers.*.self_attn.v_proj": "colwise",
+     "layers.*.self_attn.dt_proj": "colwise",
+     "layers.*.self_attn.o_proj": "rowwise",
+     "layers.*.mlp.gate_proj": "colwise",
+     "layers.*.mlp.up_proj": "colwise",
+     "layers.*.mlp.down_proj": "rowwise",
+ }
 
   def __init__(
       self,
@@ -210,3 +221,6 @@ class DogeConfig(PretrainedConfig):
       tie_word_embeddings=tie_word_embeddings,
       **kwargs,
   )
+
+
+ __all__ = ["DogeConfig"]
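`base_model_tp_plan` is the new default tensor-parallel plan: `colwise` entries shard a projection's output features across ranks and `rowwise` entries shard its input features, so the q/k/v/dt projections split per rank while `o_proj` and `mlp.down_proj` recombine the shards. A rough sketch of what those two sharding styles mean for a plain `nn.Linear` (illustrative helper, not repository code; bias handling omitted):

```python
# Illustrative sketch (not repo code): "colwise" vs "rowwise" sharding of an nn.Linear.
# nn.Linear stores its weight as (out_features, in_features).
import torch
import torch.nn as nn

def shard_linear(layer: nn.Linear, style: str, rank: int, world_size: int) -> nn.Linear:
    if style == "colwise":    # shard output features: slice dim 0 of the weight
        out_chunk = layer.out_features // world_size
        w = layer.weight[rank * out_chunk:(rank + 1) * out_chunk]
        shard = nn.Linear(layer.in_features, out_chunk, bias=False)
    elif style == "rowwise":  # shard input features: slice dim 1 of the weight
        in_chunk = layer.in_features // world_size
        w = layer.weight[:, rank * in_chunk:(rank + 1) * in_chunk]
        shard = nn.Linear(in_chunk, layer.out_features, bias=False)
    else:
        raise ValueError(f"unknown style: {style}")
    with torch.no_grad():
        shard.weight.copy_(w)
    return shard

# e.g. a per-rank slice of the attention output projection:
# o_proj_rank0 = shard_linear(nn.Linear(256, 256), "rowwise", rank=0, world_size=2)
```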
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 0,
   "eos_token_id": 1,
   "pad_token_id": 2,
- "transformers_version": "4.47.1"
+ "transformers_version": "4.48.1"
   }
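generation_config.json likewise only bumps the version stamp; the special token ids (bos=0, eos=1, pad=2) are unchanged. A quick way to inspect them, assuming the checkpoint id used in the docstring example inside modeling_doge.py:

```python
# Illustrative: read the generation defaults shipped with the checkpoint.
# The repo id is taken from the example in modeling_doge.py and is an assumption here.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("JingzeShi/Doge-20M-Instruct")
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id)  # expected: 0 1 2
```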
modeling_doge.py CHANGED
@@ -19,7 +19,7 @@
19
  """PyTorch Doge model."""
20
 
21
  import math
22
- from typing import List, Optional, Tuple, Union
23
 
24
  import torch
25
  import torch.nn.functional as F
@@ -36,7 +36,9 @@ from transformers.modeling_outputs import (
36
  )
37
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
38
  from transformers.modeling_utils import PreTrainedModel
 
39
  from transformers.utils import (
 
40
  add_start_docstrings,
41
  add_start_docstrings_to_model_forward,
42
  is_torch_greater_or_equal,
@@ -205,51 +207,66 @@ class DogeDynamicMaskAttention(nn.Module):
205
 
206
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
207
  super().__init__()
208
-
209
  self.config = config
210
  self.layer_idx = layer_idx
211
- if layer_idx is None:
212
- logger.warning_once(
213
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. "
214
- "Please make sure to provide a `layer_idx` when creating this class."
215
- )
216
-
217
- self.hidden_dim = config.hidden_size
218
- self.num_heads = config.num_attention_heads
219
- self.head_dim = self.hidden_dim // self.num_heads
220
- self.num_key_value_heads = config.num_key_value_heads
221
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
222
  self.attention_dropout = config.attention_dropout
223
  self.dynamic_mask_ratio = config.dynamic_mask_ratio
224
 
 
 
 
 
 
 
225
  # Q K V O projections
226
- self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=config.hidden_bias)
227
- self.k_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
228
- self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
229
  # dynamic mask for the QK^T attention score matrix
230
- self.A = nn.Parameter(torch.ones(self.num_heads))
231
- self.dt_proj = nn.Linear(self.num_key_value_heads * self.head_dim, self.num_heads, bias=config.hidden_bias)
232
- self.o_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)
233
 
234
  def forward(
235
  self,
236
  hidden_states: torch.Tensor,
 
237
  attention_mask: Optional[torch.Tensor] = None,
238
- position_ids: Optional[torch.LongTensor] = None,
239
  past_key_value: Optional[Cache] = None,
240
  cache_position: Optional[torch.LongTensor] = None,
241
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
242
  **kwargs,
243
  ) -> Tuple[torch.Tensor, Optional[Cache]]:
244
- bsz, q_len, _ = hidden_states.shape
 
245
 
246
- query_states = self.q_proj(hidden_states)
247
- key_states = self.k_proj(hidden_states)
248
- value_states = self.v_proj(hidden_states)
249
-
250
- query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
251
- key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
252
- value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
253
 
254
  cos, sin = position_embeddings
255
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -260,37 +277,32 @@ class DogeDynamicMaskAttention(nn.Module):
260
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
261
 
262
  # calculate dynamic mask from value_states
263
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
264
  dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
265
-
266
- # repeat key and value states
267
- key_states = repeat_kv(key_states, self.num_key_value_groups)
268
- value_states = repeat_kv(value_states, self.num_key_value_groups)
269
-
270
- # compute attention scores matrix
271
- attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.head_dim)
272
-
273
- # add mask to attention scores
274
  attn_mask = self.prepare_dynamic_mask(
275
  hidden_states=hidden_states,
276
  dynamic_mask=dynamic_mask,
277
  dynamic_mask_ratio=self.dynamic_mask_ratio,
278
  attention_mask=attention_mask,
279
  )
280
- attn_weights = attn_weights + attn_mask
281
-
282
- # upcast attention scores to fp32
283
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
284
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
285
 
286
- # apply attention scores to value states
287
- attn_output = torch.matmul(attn_weights, value_states)
288
 
289
- attn_output = attn_output.transpose(1, 2).contiguous()
290
- attn_output = attn_output.reshape(bsz, q_len, -1)
291
  attn_output = self.o_proj(attn_output)
292
-
293
- return attn_output, past_key_value
294
 
295
  def prepare_dynamic_mask(
296
  self,
@@ -318,136 +330,99 @@ class DogeDynamicMaskAttention(nn.Module):
318
  if attention_mask is not None:
319
  attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
320
  return attn_mask
321
-
322
-
323
- class DogeSdpaDynamicMaskAttention(DogeDynamicMaskAttention):
324
-
325
- def forward(
326
  self,
327
- hidden_states: torch.Tensor,
328
- attention_mask: Optional[torch.Tensor] = None,
329
- position_ids: Optional[torch.LongTensor] = None,
330
- past_key_value: Optional[Cache] = None,
331
- cache_position: Optional[torch.LongTensor] = None,
332
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
333
  **kwargs,
334
- ) -> Tuple[torch.Tensor, Optional[Cache]]:
335
- bsz, q_len, _ = hidden_states.shape
336
-
337
- query_states = self.q_proj(hidden_states)
338
- key_states = self.k_proj(hidden_states)
339
- value_states = self.v_proj(hidden_states)
340
-
341
- query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
342
- key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
343
- value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
344
-
345
- cos, sin = position_embeddings
346
- query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
347
 
348
- if past_key_value is not None:
349
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
350
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
351
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
352
 
353
- # calculate dynamic mask from value_states
354
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
355
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
356
 
357
- attn_mask = self.prepare_dynamic_mask(
358
- hidden_states=hidden_states,
359
- dynamic_mask=dynamic_mask,
360
- dynamic_mask_ratio=self.dynamic_mask_ratio,
361
- attention_mask=attention_mask,
362
- )
363
 
364
- query_states = query_states.contiguous()
365
- key_states = key_states.contiguous()
366
- value_states = value_states.contiguous()
 
 
367
 
368
  # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
369
  torch.backends.cuda.enable_cudnn_sdp(False)
370
  attn_output = F.scaled_dot_product_attention(
371
- query_states,
372
- key_states,
373
- value_states,
374
- attn_mask=attn_mask,
375
- dropout_p=self.attention_dropout if self.training else 0.0,
 
376
  enable_gqa=True,
377
  )
378
-
379
  attn_output = attn_output.transpose(1, 2).contiguous()
380
- attn_output = attn_output.view(bsz, q_len, -1)
381
- attn_output = self.o_proj(attn_output)
382
-
383
- return attn_output, past_key_value
384
-
385
-
386
- class DogeFlexDynamicMaskAttention(DogeDynamicMaskAttention):
387
-
388
- def forward(
389
  self,
390
- hidden_states: torch.Tensor,
391
- attention_mask: Optional[torch.Tensor] = None,
392
- position_ids: Optional[torch.LongTensor] = None,
393
- past_key_value: Optional[Cache] = None,
394
- cache_position: Optional[torch.LongTensor] = None,
395
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
396
  **kwargs,
397
- ) -> Tuple[torch.Tensor, Optional[Cache]]:
398
- bsz, q_len, _ = hidden_states.shape
399
-
400
- query_states = self.q_proj(hidden_states)
401
- key_states = self.k_proj(hidden_states)
402
- value_states = self.v_proj(hidden_states)
403
-
404
- query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
405
- key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
406
- value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
407
-
408
- cos, sin = position_embeddings
409
- query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
410
-
411
- if past_key_value is not None:
412
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
413
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
414
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
415
-
416
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
417
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
418
 
419
- attn_mask = self.prepare_dynamic_mask(
420
- hidden_states=hidden_states,
421
- dynamic_mask=dynamic_mask,
422
- dynamic_mask_ratio=self.dynamic_mask_ratio,
423
- attention_mask=attention_mask,
424
- )
425
  # TODO: flex_attention: Captured buffers that require grad are not yet supported.
426
  # NOTE: So we only use flex_attention in inference mode.
427
- def dynamic_mask_mod(score, batch, head, q_idx, kv_idx):
428
- score = score + attn_mask[batch][head][q_idx][kv_idx]
429
  return score
430
-
431
  attn_output = flex_attention(
432
- query_states,
433
- key_states,
434
- value_states,
435
- score_mod=dynamic_mask_mod,
 
436
  enable_gqa=True,
437
  )
438
-
439
  attn_output = attn_output.transpose(1, 2).contiguous()
440
- attn_output = attn_output.view(bsz, q_len, -1)
441
- attn_output = self.o_proj(attn_output)
442
-
443
- return attn_output, past_key_value
444
-
445
-
446
- DOGE_ATTENTION_CLASSES = {
447
- "flex_attention": DogeFlexDynamicMaskAttention,
448
- "eager": DogeDynamicMaskAttention,
449
- "sdpa": DogeSdpaDynamicMaskAttention,
450
- }
451
 
452
 
453
  class DogeMLP(nn.Module):
@@ -535,7 +510,7 @@ class DogeDecoderLayer(nn.Module):
535
  self.hidden_dropout = config.hidden_dropout
536
 
537
  self.pre_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
538
- self.self_attn = DOGE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
539
  self.pre_residual = Residual(config.hidden_size)
540
 
541
  self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -551,32 +526,14 @@ class DogeDecoderLayer(nn.Module):
551
  output_attentions: Optional[bool] = False,
552
  use_cache: Optional[bool] = False,
553
  cache_position: Optional[torch.LongTensor] = None,
554
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
555
  **kwargs,
556
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
557
- """
558
- Args:
559
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
560
- attention_mask (`torch.FloatTensor`, *optional*):
561
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used.
562
- output_attentions (`bool`, *optional*):
563
- Whether or not to return the attentions tensors of all attention layers.
564
- See `attentions` under returned tensors for more detail.
565
- use_cache (`bool`, *optional*):
566
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
567
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
568
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
569
- Indices depicting the position of the input sequence tokens in the sequence
570
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
571
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head.
572
- kwargs (`dict`, *optional*):
573
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model
574
- """
575
 
576
  # sequence transformation
577
  residual = hidden_states
578
  hidden_states = self.pre_layernorm(hidden_states)
579
- hidden_states, present_key_value = self.self_attn(
580
  hidden_states=hidden_states,
581
  attention_mask=attention_mask,
582
  position_ids=position_ids,
@@ -597,25 +554,39 @@ class DogeDecoderLayer(nn.Module):
597
  hidden_states = self.post_residual(residual, hidden_states)
598
 
599
  outputs = (hidden_states,)
600
-
601
  if output_attentions:
602
  outputs += (self_attn_weights,)
603
 
604
- if use_cache:
605
- outputs += (present_key_value,)
606
-
607
  return outputs
608
 
609
 
610
- @add_start_docstrings("The bare Doge Model outputting raw hidden-states without any specific head on top.")
611
  class DogePreTrainedModel(PreTrainedModel):
612
  config_class = DogeConfig
613
  base_model_prefix = "model"
614
  supports_gradient_checkpointing = True
615
  _no_split_modules = ["DogeDecoderLayer"]
616
  _skip_keys_device_placement = ["past_key_values"]
617
- _supports_flex_attn = True
618
  _supports_sdpa = True
 
619
  _supports_cache_class = True
620
  _supports_quantized_cache = True
621
  _supports_static_cache = True
@@ -635,10 +606,11 @@ class DogePreTrainedModel(PreTrainedModel):
635
  DOGE_INPUTS_DOCSTRING = r"""
636
  Args:
637
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
638
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
 
639
 
640
- Indices can be obtained using [`AutoTokenizer`].
641
- See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
642
 
643
  [What are input IDs?](../glossary#input-ids)
644
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -649,53 +621,75 @@ DOGE_INPUTS_DOCSTRING = r"""
649
 
650
  [What are attention masks?](../glossary#attention-mask)
651
 
652
- Indices can be obtained using [`AutoTokenizer`].
653
- See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
654
 
655
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`).
 
656
 
657
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs.
658
- See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
 
659
 
660
  - 1 indicates the head is **not masked**,
661
  - 0 indicates the head is **masked**.
662
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
663
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
 
664
 
665
  [What are position IDs?](../glossary#position-ids)
666
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
667
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding.
668
- This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
669
 
670
  Two formats are allowed:
671
- - a [`~cache_utils.Cache`] instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
672
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format.
673
-
674
- The model will output the same cache format that is fed as input.
675
- If no `past_key_values` are passed, the legacy cache format will be returned.
676
-
677
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`.
 
 
 
 
 
678
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
679
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
680
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
 
681
  use_cache (`bool`, *optional*):
682
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
 
683
  output_attentions (`bool`, *optional*):
684
- Whether or not to return the attentions tensors of all attention layers.
685
- See `attentions` under returned tensors for more detail.
686
  output_hidden_states (`bool`, *optional*):
687
- Whether or not to return the hidden states of all layers.
688
- See `hidden_states` under returned tensors for more detail.
689
  return_dict (`bool`, *optional*):
690
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
691
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
692
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding.
693
- It is used to update the cache in the correct position and to infer the complete sequence length.
 
694
  """
695
 
696
 
697
- @add_start_docstrings("The bare Doge Model outputting raw hidden-states without any specific head on top.")
 
 
 
698
  class DogeModel(DogePreTrainedModel):
 
 
 
 
 
 
 
699
  def __init__(self, config: DogeConfig):
700
  super().__init__(config)
701
  self.config = config
@@ -732,6 +726,7 @@ class DogeModel(DogePreTrainedModel):
732
  output_hidden_states: Optional[bool] = None,
733
  return_dict: Optional[bool] = None,
734
  cache_position: Optional[torch.LongTensor] = None,
 
735
  ) -> Union[Tuple, BaseModelOutputWithPast]:
736
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
737
  output_hidden_states = (
@@ -752,33 +747,22 @@ class DogeModel(DogePreTrainedModel):
752
  if inputs_embeds is None:
753
  inputs_embeds = self.word_embed(input_ids)
754
 
755
- # kept for BC (non `Cache` `past_key_values` inputs)
756
- return_legacy_cache = False
757
- if use_cache and not isinstance(past_key_values, Cache):
758
- return_legacy_cache = True
759
- if past_key_values is None:
760
- past_key_values = DynamicCache()
761
- else:
762
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
763
- logger.warning_once(
764
- "We detected that you are passing `past_key_values` as a tuple of tuples."
765
- "This is deprecated and will be removed in v4.47."
766
- "Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
767
- )
768
 
769
  if cache_position is None:
770
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
771
  cache_position = torch.arange(
772
- past_seen_tokens,
773
- past_seen_tokens + inputs_embeds.shape[1],
774
- device=inputs_embeds.device,
775
  )
 
776
  if position_ids is None:
777
  position_ids = cache_position.unsqueeze(0)
778
 
779
  causal_mask = self._update_causal_mask(
780
  attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
781
  )
 
782
  hidden_states = inputs_embeds
783
 
784
  # create position embeddings to be shared across the decoder layers
@@ -787,7 +771,6 @@ class DogeModel(DogePreTrainedModel):
787
  # decoder layers
788
  all_hidden_states = () if output_hidden_states else None
789
  all_self_attns = () if output_attentions else None
790
- next_decoder_cache = None
791
 
792
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
793
  if output_hidden_states:
@@ -815,13 +798,11 @@ class DogeModel(DogePreTrainedModel):
815
  use_cache=use_cache,
816
  cache_position=cache_position,
817
  position_embeddings=position_embeddings,
 
818
  )
819
 
820
  hidden_states = layer_outputs[0]
821
 
822
- if use_cache:
823
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
824
-
825
  if output_attentions:
826
  all_self_attns += (layer_outputs[1],)
827
 
@@ -831,27 +812,21 @@ class DogeModel(DogePreTrainedModel):
831
  if output_hidden_states:
832
  all_hidden_states += (hidden_states,)
833
 
834
- next_cache = next_decoder_cache if use_cache else None
835
- if return_legacy_cache:
836
- next_cache = next_cache.to_legacy_cache()
837
-
838
- if not return_dict:
839
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
840
-
841
- return BaseModelOutputWithPast(
842
  last_hidden_state=hidden_states,
843
- past_key_values=next_cache,
844
  hidden_states=all_hidden_states,
845
  attentions=all_self_attns,
846
  )
 
847
 
848
  def _update_causal_mask(
849
  self,
850
- attention_mask: torch.Tensor = None,
851
- input_tensor: torch.Tensor = None,
852
- cache_position: torch.Tensor = None,
853
- past_key_values: Cache = None,
854
- output_attentions: bool = False,
855
  ):
856
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
857
  using_static_cache = isinstance(past_key_values, StaticCache)
@@ -892,15 +867,18 @@ class DogeModel(DogePreTrainedModel):
892
  **kwargs,
893
  ):
894
  """
895
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
 
896
 
897
  Args:
898
  attention_mask (`torch.Tensor`):
899
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
 
900
  sequence_length (`int`):
901
  The sequence length being processed.
902
  target_length (`int`):
903
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
 
904
  dtype (`torch.dtype`):
905
  The dtype to use for the 4D attention mask.
906
  device (`torch.device`):
@@ -935,8 +913,12 @@ class DogeModel(DogePreTrainedModel):
935
  return causal_mask
936
 
937
 
 
 
 
938
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
939
  _tied_weights_keys = ["lm_head.weight"]
 
940
 
941
  def __init__(self, config: DogeConfig):
942
  super().__init__(config)
@@ -982,22 +964,38 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
982
  return_dict: Optional[bool] = None,
983
  cache_position: Optional[torch.LongTensor] = None,
984
  num_logits_to_keep: int = 0,
985
- **kwargs,
986
  ) -> Union[Tuple, CausalLMOutputWithPast]:
987
  r"""
988
  Args:
989
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
990
- Labels for computing the masked language modeling loss.
991
- Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring).
992
- Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
993
 
994
  num_logits_to_keep (`int`, *optional*):
995
- Calculate logits for the last `num_logits_to_keep` tokens.
996
- If `0`, calculate logits for all `input_ids` (special case).
997
- Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
998
 
999
  Returns:
1000
- """
1001
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1002
  output_hidden_states = (
1003
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
19
  """PyTorch Doge model."""
20
 
21
  import math
22
+ from typing import Callable, List, Optional, Tuple, Union
23
 
24
  import torch
25
  import torch.nn.functional as F
 
36
  )
37
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
38
  from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
  from transformers.utils import (
41
+ LossKwargs,
42
  add_start_docstrings,
43
  add_start_docstrings_to_model_forward,
44
  is_torch_greater_or_equal,
 
207
 
208
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
209
  super().__init__()
 
210
  self.config = config
211
  self.layer_idx = layer_idx
212
+ self.head_dim = config.hidden_size // config.num_attention_heads
213
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
214
+ self.scaling = self.head_dim ** -0.5
 
 
 
 
 
 
 
 
215
  self.attention_dropout = config.attention_dropout
216
  self.dynamic_mask_ratio = config.dynamic_mask_ratio
217
 
218
+ self.ALL_ATTENTION_FUNCTIONS = {
219
+ "eager": self.eager_attention_forward,
220
+ "sdpa": self.sdpa_attention_forward,
221
+ "flex_attention": self.flex_attention_forward,
222
+ }
223
+
224
  # Q K V O projections
225
+ self.q_proj = nn.Linear(
226
+ config.hidden_size,
227
+ config.num_attention_heads * self.head_dim,
228
+ bias=config.hidden_bias
229
+ )
230
+ self.k_proj = nn.Linear(
231
+ config.hidden_size,
232
+ config.num_key_value_heads * self.head_dim,
233
+ bias=config.hidden_bias
234
+ )
235
+ self.v_proj = nn.Linear(
236
+ config.hidden_size,
237
+ config.num_key_value_heads * self.head_dim,
238
+ bias=config.hidden_bias
239
+ )
240
  # dynamic mask for the QK^T attention score matrix
241
+ self.A = nn.Parameter(
242
+ torch.ones(config.num_attention_heads)
243
+ )
244
+ self.dt_proj = nn.Linear(
245
+ config.num_key_value_heads * self.head_dim,
246
+ config.num_attention_heads,
247
+ bias=config.hidden_bias
248
+ )
249
+ self.o_proj = nn.Linear(
250
+ config.num_attention_heads * self.head_dim,
251
+ config.hidden_size,
252
+ bias=config.hidden_bias
253
+ )
254
 
255
  def forward(
256
  self,
257
  hidden_states: torch.Tensor,
258
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
259
  attention_mask: Optional[torch.Tensor] = None,
 
260
  past_key_value: Optional[Cache] = None,
261
  cache_position: Optional[torch.LongTensor] = None,
 
262
  **kwargs,
263
  ) -> Tuple[torch.Tensor, Optional[Cache]]:
264
+ input_shape = hidden_states.shape[:-1]
265
+ hidden_shape = (*input_shape, -1, self.head_dim)
266
 
267
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
268
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
269
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
 
 
 
270
 
271
  cos, sin = position_embeddings
272
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
 
277
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
278
 
279
  # calculate dynamic mask from value_states
280
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1))
281
  dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
 
 
 
 
 
 
 
 
 
282
  attn_mask = self.prepare_dynamic_mask(
283
  hidden_states=hidden_states,
284
  dynamic_mask=dynamic_mask,
285
  dynamic_mask_ratio=self.dynamic_mask_ratio,
286
  attention_mask=attention_mask,
287
  )
 
 
 
 
 
288
 
289
+ attention_interface: Callable = self.eager_attention_forward
290
+ if self.config._attn_implementation != "eager":
291
+ attention_interface = self.ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
292
+
293
+ attn_output = attention_interface(
294
+ query_states,
295
+ key_states,
296
+ value_states,
297
+ attention_mask=attn_mask,
298
+ dropout=0.0 if not self.training else self.attention_dropout,
299
+ scaling=self.scaling,
300
+ **kwargs,
301
+ )
302
 
303
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
304
  attn_output = self.o_proj(attn_output)
305
+ return attn_output
 
306
 
307
  def prepare_dynamic_mask(
308
  self,
 
330
  if attention_mask is not None:
331
  attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
332
  return attn_mask
333
+
334
+ def eager_attention_forward(
 
 
 
335
  self,
336
+ query: torch.Tensor,
337
+ key: torch.Tensor,
338
+ value: torch.Tensor,
339
+ attention_mask: Optional[torch.Tensor],
340
+ scaling: float,
341
+ dropout: float = 0.0,
342
  **kwargs,
343
+ ) -> torch.Tensor:
344
+ key_states = repeat_kv(key, self.num_key_value_groups)
345
+ value_states = repeat_kv(value, self.num_key_value_groups)
346
 
347
+ # compute attention scores matrix
348
+ attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling
349
+ if attention_mask is not None:
350
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
351
+ attn_weights = attn_weights + causal_mask
352
 
353
+ # upcast attention scores to fp32
354
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
355
+ attn_weights = F.dropout(attn_weights, p=dropout, training=self.training)
356
 
357
+ # apply attention scores to value states
358
+ attn_output = torch.matmul(attn_weights, value_states)
359
+ attn_output = attn_output.transpose(1, 2).contiguous()
360
+ return attn_output
361
+
362
+ def sdpa_attention_forward(
363
+ self,
364
+ query: torch.Tensor,
365
+ key: torch.Tensor,
366
+ value: torch.Tensor,
367
+ attention_mask: Optional[torch.Tensor],
368
+ scaling: float,
369
+ dropout: float = 0.0,
370
+ **kwargs,
371
+ ) -> torch.Tensor:
372
+ causal_mask = attention_mask
373
+ if attention_mask is not None:
374
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
375
 
376
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
377
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
378
+ query = query.contiguous()
379
+ key = key.contiguous()
380
+ value = value.contiguous()
381
 
382
  # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
383
  torch.backends.cuda.enable_cudnn_sdp(False)
384
  attn_output = F.scaled_dot_product_attention(
385
+ query,
386
+ key,
387
+ value,
388
+ attn_mask=causal_mask,
389
+ dropout_p=dropout,
390
+ scale=scaling,
391
  enable_gqa=True,
392
  )
 
393
  attn_output = attn_output.transpose(1, 2).contiguous()
394
+ return attn_output
395
+
396
+ def flex_attention_forward(
 
 
 
 
 
 
397
  self,
398
+ query: torch.Tensor,
399
+ key: torch.Tensor,
400
+ value: torch.Tensor,
401
+ attention_mask: Optional[torch.Tensor],
402
+ scaling: float,
403
+ dropout: float = 0.0,
404
  **kwargs,
405
+ ) -> torch.Tensor:
406
+ causal_mask = attention_mask
407
+ if attention_mask is not None:
408
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
409
 
 
 
 
 
 
 
410
  # TODO: flex_attention: Captured buffers that require grad are not yet supported.
411
  # NOTE: So we only use flex_attention in inference mode.
412
+ def mask_mod(score, batch, head, q_idx, kv_idx):
413
+ score = score + causal_mask[batch][head][q_idx][kv_idx]
414
  return score
415
+
416
  attn_output = flex_attention(
417
+ query,
418
+ key,
419
+ value,
420
+ score_mod=mask_mod,
421
+ scale=scaling,
422
  enable_gqa=True,
423
  )
 
424
  attn_output = attn_output.transpose(1, 2).contiguous()
425
+ return attn_output
426
 
427
 
428
  class DogeMLP(nn.Module):
 
510
  self.hidden_dropout = config.hidden_dropout
511
 
512
  self.pre_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
513
+ self.self_attn = DogeDynamicMaskAttention(config=config, layer_idx=layer_idx)
514
  self.pre_residual = Residual(config.hidden_size)
515
 
516
  self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
526
  output_attentions: Optional[bool] = False,
527
  use_cache: Optional[bool] = False,
528
  cache_position: Optional[torch.LongTensor] = None,
529
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
530
  **kwargs,
531
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
532
 
533
  # sequence transformation
534
  residual = hidden_states
535
  hidden_states = self.pre_layernorm(hidden_states)
536
+ hidden_states = self.self_attn(
537
  hidden_states=hidden_states,
538
  attention_mask=attention_mask,
539
  position_ids=position_ids,
 
554
  hidden_states = self.post_residual(residual, hidden_states)
555
 
556
  outputs = (hidden_states,)
 
557
  if output_attentions:
558
  outputs += (self_attn_weights,)
559
 
 
 
 
560
  return outputs
561
 
562
 
563
+ DOGE_START_DOCSTRING = r"""
564
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
565
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
566
+ etc.)
567
+
568
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
569
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
570
+ and behavior.
571
+
572
+ Parameters:
573
+ config ([`DogeConfig`]):
574
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
575
+ load the weights associated with the model, only the configuration. Check out the
576
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
577
+ """
578
+ @add_start_docstrings(
579
+ "The bare Doge Model outputting raw hidden-states without any specific head on top.",
580
+ DOGE_START_DOCSTRING,
581
+ )
582
  class DogePreTrainedModel(PreTrainedModel):
583
  config_class = DogeConfig
584
  base_model_prefix = "model"
585
  supports_gradient_checkpointing = True
586
  _no_split_modules = ["DogeDecoderLayer"]
587
  _skip_keys_device_placement = ["past_key_values"]
 
588
  _supports_sdpa = True
589
+ _supports_flex_attn = True
590
  _supports_cache_class = True
591
  _supports_quantized_cache = True
592
  _supports_static_cache = True
 
606
  DOGE_INPUTS_DOCSTRING = r"""
607
  Args:
608
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
609
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
610
+ it.
611
 
612
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
613
+ [`PreTrainedTokenizer.__call__`] for details.
614
 
615
  [What are input IDs?](../glossary#input-ids)
616
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
 
621
 
622
  [What are attention masks?](../glossary#attention-mask)
623
 
624
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
625
+ [`PreTrainedTokenizer.__call__`] for details.
626
 
627
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
628
+ `past_key_values`).
629
 
630
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
631
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
632
+ information on the default strategy.
633
 
634
  - 1 indicates the head is **not masked**,
635
  - 0 indicates the head is **masked**.
636
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
637
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
638
+ config.n_positions - 1]`.
639
 
640
  [What are position IDs?](../glossary#position-ids)
641
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
642
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
643
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
644
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
645
 
646
  Two formats are allowed:
647
+ - a [`~cache_utils.Cache`] instance, see our
648
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
649
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
650
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
651
+ cache format.
652
+
653
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
654
+ legacy cache format will be returned.
655
+
656
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
657
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
658
+ of shape `(batch_size, sequence_length)`.
659
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
660
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
661
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
662
+ model's internal embedding lookup matrix.
663
  use_cache (`bool`, *optional*):
664
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
665
+ `past_key_values`).
666
  output_attentions (`bool`, *optional*):
667
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
668
+ tensors for more detail.
669
  output_hidden_states (`bool`, *optional*):
670
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
671
+ more detail.
672
  return_dict (`bool`, *optional*):
673
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
674
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
675
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
676
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
677
+ the complete sequence length.
678
  """
679
 
680
 
681
+ @add_start_docstrings(
682
+ "The bare Doge Model outputting raw hidden-states without any specific head on top.",
683
+ DOGE_START_DOCSTRING,
684
+ )
685
  class DogeModel(DogePreTrainedModel):
686
+ """
687
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DogeDecoderLayer`]
688
+
689
+ Args:
690
+ config: DogeConfig
691
+ """
692
+
693
  def __init__(self, config: DogeConfig):
694
  super().__init__(config)
695
  self.config = config
 
726
  output_hidden_states: Optional[bool] = None,
727
  return_dict: Optional[bool] = None,
728
  cache_position: Optional[torch.LongTensor] = None,
729
+ **kwargs,
730
  ) -> Union[Tuple, BaseModelOutputWithPast]:
731
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
732
  output_hidden_states = (
 
747
  if inputs_embeds is None:
748
  inputs_embeds = self.word_embed(input_ids)
749
 
750
+ if use_cache and past_key_values is None:
751
+ past_key_values = DynamicCache()
752
 
753
  if cache_position is None:
754
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
755
  cache_position = torch.arange(
756
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
757
  )
758
+
759
  if position_ids is None:
760
  position_ids = cache_position.unsqueeze(0)
761
 
762
  causal_mask = self._update_causal_mask(
763
  attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
764
  )
765
+
766
  hidden_states = inputs_embeds
767
 
768
  # create position embeddings to be shared across the decoder layers
 
771
  # decoder layers
772
  all_hidden_states = () if output_hidden_states else None
773
  all_self_attns = () if output_attentions else None
 
774
 
775
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
776
  if output_hidden_states:
 
798
  use_cache=use_cache,
799
  cache_position=cache_position,
800
  position_embeddings=position_embeddings,
801
+ **kwargs,
802
  )
803
 
804
  hidden_states = layer_outputs[0]
805
 
 
 
 
806
  if output_attentions:
807
  all_self_attns += (layer_outputs[1],)
808
 
 
812
  if output_hidden_states:
813
  all_hidden_states += (hidden_states,)
814
 
815
+ output = BaseModelOutputWithPast(
 
 
 
 
 
 
 
816
  last_hidden_state=hidden_states,
817
+ past_key_values=past_key_values if use_cache else None,
818
  hidden_states=all_hidden_states,
819
  attentions=all_self_attns,
820
  )
821
+ return output if return_dict else output.to_tuple()
822
 
823
  def _update_causal_mask(
824
  self,
825
+ attention_mask: torch.Tensor,
826
+ input_tensor: torch.Tensor,
827
+ cache_position: torch.Tensor,
828
+ past_key_values: Cache,
829
+ output_attentions: bool,
830
  ):
831
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
832
  using_static_cache = isinstance(past_key_values, StaticCache)
 
867
  **kwargs,
868
  ):
869
  """
870
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
871
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
872
 
873
  Args:
874
  attention_mask (`torch.Tensor`):
875
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
876
+ `(batch_size, 1, query_length, key_value_length)`.
877
  sequence_length (`int`):
878
  The sequence length being processed.
879
  target_length (`int`):
880
+ The target length: when generating with static cache, the mask should be as long as the static cache,
881
+ to account for the 0 padding, the part of the cache that is not filled yet.
882
  dtype (`torch.dtype`):
883
  The dtype to use for the 4D attention mask.
884
  device (`torch.device`):
 
913
  return causal_mask
914
 
915
 
916
+ class KwargsForCausalLM(LossKwargs): ...
917
+
918
+
919
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
920
  _tied_weights_keys = ["lm_head.weight"]
921
+ _tp_plan = {"lm_head": "colwise_rep"}
922
 
923
  def __init__(self, config: DogeConfig):
924
  super().__init__(config)
 
964
  return_dict: Optional[bool] = None,
965
  cache_position: Optional[torch.LongTensor] = None,
966
  num_logits_to_keep: int = 0,
967
+ **kwargs: Unpack[KwargsForCausalLM],
968
  ) -> Union[Tuple, CausalLMOutputWithPast]:
969
  r"""
970
  Args:
971
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
972
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
973
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
974
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
975
 
976
  num_logits_to_keep (`int`, *optional*):
977
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
978
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
979
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
980
 
981
  Returns:
982
+
983
+ Example:
984
+
985
+ ```python
986
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
987
+
988
+ >>> model = AutoModelForCausalLM.from_pretrained("JingzeShi/Doge-20M-Instruct")
989
+ >>> tokenizer = AutoTokenizer.from_pretrained("JingzeShi/Doge-20M-Instruct")
990
+
991
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
992
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
993
+
994
+ >>> # Generate
995
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
996
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
997
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
998
+ ```"""
999
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1000
  output_hidden_states = (
1001
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
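The modeling refactor above folds the separate eager/SDPA/flex attention subclasses into a single `DogeDynamicMaskAttention` that dispatches through `self.ALL_ATTENTION_FUNCTIONS[config._attn_implementation]`, and it keeps the dynamic mask `exp(A * softplus(dt_proj(V)))` as an additive term on the attention scores. A self-contained toy sketch of that scoring pattern (toy sizes, no KV cache, no GQA, simplified masking; not the repository implementation):

```python
# Illustrative sketch of dynamic-mask attention scoring; not the repo's code path.
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# A is a learned per-head scale; dt_proj maps flattened value states to one value per head.
A = torch.ones(heads)
dt_proj = torch.nn.Linear(heads * head_dim, heads)
dt_states = dt_proj(v.transpose(1, 2).reshape(batch, seq, -1))          # (batch, seq, heads)
dynamic_mask = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)   # (batch, heads, seq)

# broadcast the per-key dynamic values over queries and add a causal mask
causal = torch.full((seq, seq), torch.finfo(q.dtype).min).triu(1)
attn_mask = dynamic_mask[:, :, None, :] + causal                        # (batch, heads, seq, seq)

# eager path: scaled scores + additive mask, softmax in fp32, weighted sum of values
scaling = head_dim ** -0.5
scores = q @ k.transpose(-1, -2) * scaling + attn_mask
weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
out = weights @ v                                                        # (batch, heads, seq, head_dim)

# SDPA path: the same additive mask is handed to scaled_dot_product_attention
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scaling)
```

The last call mirrors the new `sdpa_attention_forward`, which forwards the prepared mask via `attn_mask` together with the shared `scaling` factor; the eager block above corresponds to `eager_attention_forward`.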