wuhp committed on
Commit 1520350 · verified · 1 Parent(s): 94baf3a

Update myr1/modeling_deepseek.py

Files changed (1)
  1. myr1/modeling_deepseek.py +536 -195
myr1/modeling_deepseek.py CHANGED
@@ -54,17 +54,19 @@ logger = logging.get_logger(__name__)
54
 
55
  # If flash-attn is available
56
  if is_flash_attn_2_available():
57
- from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_paged_kv
58
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
59
 
60
  # This helps make `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
61
  if is_torch_fx_available():
62
  if not is_torch_greater_or_equal_than_1_13:
63
  import torch.fx
 
64
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
65
 
66
  _CONFIG_FOR_DOC = "DeepseekV3Config"
67
 
 
68
  # ==============================================================================
69
  # Rotary Embedding Helpers
70
  # ==============================================================================
@@ -80,9 +82,11 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.T
80
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
81
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
82
  max_seqlen_in_batch = seqlens_in_batch.max().item()
 
83
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
84
  return indices, cu_seqlens, max_seqlen_in_batch
85
 
 
86
  # ==============================================================================
87
  # Normalization Layers
88
  # ==============================================================================
@@ -98,13 +102,16 @@ class DeepseekV3RMSNorm(nn.Module):
98
  self.variance_epsilon = eps
99
 
100
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
101
  input_dtype = hidden_states.dtype
102
  variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
103
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
104
  return (self.weight * hidden_states).to(input_dtype)
105
 
 
106
  ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
107
 
 
108
  # ==============================================================================
109
  # Rotary Embeddings
110
  # ==============================================================================
@@ -125,20 +132,25 @@ class DeepseekV3RotaryEmbedding(nn.Module):
125
  self.max_position_embeddings = max_position_embeddings
126
  self.base = base
127
 
128
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
129
  self.register_buffer("inv_freq", inv_freq, persistent=False)
130
- # Build here to make torch.jit.trace work.
 
131
  self._set_cos_sin_cache(
132
  seq_len=max_position_embeddings,
133
  device=self.inv_freq.device,
134
  dtype=torch.get_default_dtype(),
135
  )
136
- self.max_seq_len_cached = max_position_embeddings
137
 
138
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
139
  self.max_seq_len_cached = seq_len
140
- t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
141
- freqs = torch.outer(t, self.inv_freq.to(device))
 
 
142
  emb = torch.cat((freqs, freqs), dim=-1)
143
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
144
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
@@ -147,17 +159,16 @@ class DeepseekV3RotaryEmbedding(nn.Module):
147
  """
148
  x: [batch_size, num_heads, seq_len, head_size]
149
  """
150
- if seq_len is None:
151
- seq_len = x.shape[-2]
152
- if seq_len > self.max_seq_len_cached:
153
- self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)
154
- return (self.cos_cached[:seq_len].to(x.dtype),
155
- self.sin_cached[:seq_len].to(x.dtype))
156
 
157
 
158
  class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
159
  """
160
- RoPE extended with linear scaling. Credits to the Reddit user /u/kaiokendev.
161
  """
162
  def __init__(
163
  self,
@@ -172,7 +183,8 @@ class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
172
 
173
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
174
  self.max_seq_len_cached = seq_len
175
- t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
 
176
  freqs = torch.outer(t, self.inv_freq)
177
  emb = torch.cat((freqs, freqs), dim=-1)
178
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -182,7 +194,7 @@ class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
182
  class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
183
  """
184
  RoPE extended with Dynamic NTK scaling.
185
- Credits to the Reddit users /u/bloc97 and /u/emozilla.
186
  """
187
  def __init__(
188
  self,
@@ -197,28 +209,34 @@ class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
197
 
198
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
199
  self.max_seq_len_cached = seq_len
 
200
  if seq_len > self.max_position_embeddings:
201
  base = self.base * (
202
  (self.scaling_factor * seq_len / self.max_position_embeddings)
203
  - (self.scaling_factor - 1)
204
  ) ** (self.dim / (self.dim - 2))
205
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
206
  self.register_buffer("inv_freq", inv_freq, persistent=False)
207
- t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
 
208
  freqs = torch.outer(t, self.inv_freq)
209
  emb = torch.cat((freqs, freqs), dim=-1)
210
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
211
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
212
 
213
 
214
- # Extra Yarn-based formulas from your original code
215
  def yarn_find_correction_dim(
216
  num_rotations: float,
217
  dim: int,
218
  base: int = 10000,
219
  max_position_embeddings: int = 2048
220
  ):
221
- return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
 
 
222
 
223
 
224
  def yarn_find_correction_range(
@@ -228,8 +246,13 @@ def yarn_find_correction_range(
228
  base: int = 10000,
229
  max_position_embeddings: int = 2048
230
  ):
231
- low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
232
- high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
233
  return max(low, 0), min(high, dim - 1)
234
 
235
 
@@ -275,21 +298,39 @@ class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
275
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
276
  self.max_seq_len_cached = seq_len
277
  dim = self.dim
278
- freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
279
- freq_inter = 1.0 / (self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
280
- low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, dim, self.base, self.original_max_position_embeddings)
281
  inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
282
  inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
283
  self.register_buffer("inv_freq", inv_freq, persistent=False)
 
284
  t = torch.arange(seq_len, device=device, dtype=torch.float32)
285
  freqs = torch.outer(t, inv_freq)
286
- _mscale = float(yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
287
  emb = torch.cat((freqs, freqs), dim=-1)
288
  self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
289
  self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
290
 
291
 
292
- # ==============================================================================
293
  # General Rotary helper functions
294
  # ==============================================================================
295
 
@@ -339,7 +380,9 @@ class DeepseekV3MLP(nn.Module):
339
  super().__init__()
340
  self.config = config
341
  self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
342
- self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
 
 
343
 
344
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
345
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -370,9 +413,14 @@ class MoEGate(nn.Module):
370
  self.norm_topk_prob = config.norm_topk_prob
371
  self.gating_dim = config.hidden_size
372
 
 
373
  self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
 
374
  if self.topk_method == "noaux_tc":
375
- self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
376
  self.reset_parameters()
377
 
378
  def reset_parameters(self):
@@ -385,29 +433,46 @@ class MoEGate(nn.Module):
385
  Compute gating scores and select top-k experts.
386
  """
387
  bsz, seq_len, h = hidden_states.shape
 
 
388
  logits = F.linear(hidden_states.float(), self.weight.float(), None)
389
  if self.scoring_func == "sigmoid":
390
  scores = logits.sigmoid()
391
  else:
392
- raise NotImplementedError(f"Unsupported gating scoring function: {self.scoring_func}")
 
 
393
 
 
394
  if self.topk_method == "noaux_tc":
 
 
395
  scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
396
- group_scores = (scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1))
397
- group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
398
  group_mask = torch.zeros_like(group_scores)
399
  group_mask.scatter_(1, group_idx, 1)
400
- score_mask = group_mask.unsqueeze(-1).expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group).reshape(bsz * seq_len, -1)
 
 
401
  tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
402
  _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
403
  topk_weight = scores_for_choice.gather(1, topk_idx)
404
  else:
405
- raise NotImplementedError(f"Unsupported topk_method: {self.topk_method}")
 
 
406
 
 
407
  if self.top_k > 1 and self.norm_topk_prob:
408
  denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
409
  topk_weight = topk_weight / denominator
410
 
 
411
  topk_weight = topk_weight * self.routed_scaling_factor
412
 
413
  return topk_idx, topk_weight
@@ -430,43 +495,62 @@ class DeepseekV3MoE(nn.Module):
430
  self.experts_per_rank = config.n_routed_experts // config.ep_size
431
  self.ep_rank = dist.get_rank()
432
 
 
433
  experts_list = []
434
  for i in range(config.n_routed_experts):
 
435
  if self.ep_size > 1:
436
  if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank:
437
- experts_list.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))
 
 
438
  else:
439
  experts_list.append(None)
440
  else:
441
- experts_list.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))
 
 
442
  self.experts = nn.ModuleList(experts_list)
443
 
 
444
  self.gate = MoEGate(config)
445
 
 
446
  if config.n_shared_experts is not None:
447
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
448
- self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
 
 
449
  else:
450
  self.shared_experts = None
451
 
452
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
453
  identity = hidden_states
454
  orig_shape = hidden_states.shape
 
455
  topk_idx, topk_weight = self.gate(hidden_states)
 
456
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
 
457
  if not self.training:
458
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
459
  else:
 
 
 
460
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
 
 
461
  if self.shared_experts is not None:
462
  y = y + self.shared_experts(identity)
 
463
  return y
464
 
465
  @torch.no_grad()
466
  def moe_infer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
467
  """
468
- MoE inference path for each token. Processes experts in parallel and combines
469
- results via an efficient scatter-add.
470
  """
471
  cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
472
  cnts.scatter_(1, topk_ids, 1)
@@ -475,16 +559,30 @@ class DeepseekV3MoE(nn.Module):
475
  sorted_tokens = x[idxs // topk_ids.shape[1]]
476
  sorted_tokens_shape = sorted_tokens.shape
477
 
 
478
  if self.ep_size > 1:
479
  tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
480
  tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
481
  dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
482
- output_splits = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(1).cpu().numpy().tolist()
483
- gathered_tokens = sorted_tokens.new_empty(tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1])
484
  input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
485
- dist.all_to_all(list(gathered_tokens.split(input_split_sizes)), list(sorted_tokens.split(input_split_sizes)))
486
- tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
487
- gatherd_idxs = np.zeros((gathered_tokens.shape[0],), dtype=np.int32)
488
  s = 0
489
  for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
490
  gatherd_idxs[s : s + k] = i % self.experts_per_rank
@@ -494,8 +592,10 @@ class DeepseekV3MoE(nn.Module):
494
  tokens_per_expert = tokens_per_expert_post_gather
495
 
496
  tokens_per_expert = tokens_per_expert.cpu().numpy()
 
497
  outputs = []
498
  start_idx = 0
 
499
  for i, num_tokens in enumerate(tokens_per_expert):
500
  end_idx = start_idx + num_tokens
501
  if num_tokens == 0:
@@ -505,19 +605,32 @@ class DeepseekV3MoE(nn.Module):
505
  expert_out = expert(tokens_for_this_expert) if expert else tokens_for_this_expert
506
  outputs.append(expert_out)
507
  start_idx = end_idx
508
- outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
509
  if self.ep_size > 1:
510
  new_x = torch.empty_like(outs)
511
  new_x[gatherd_idxs] = outs
512
  gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
513
- dist.all_to_all(list(gathered_tokens.split(input_split_sizes)), list(new_x.split(output_splits)))
514
  outs = gathered_tokens
 
515
  new_x = torch.empty_like(outs)
516
  new_x[idxs] = outs
517
- final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype)
518
- .mul_(topk_weight.unsqueeze(dim=-1))
519
- .sum(dim=1)
520
- .type(new_x.dtype))
521
  return final_out
522
 
523
 
@@ -530,7 +643,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
530
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
531
  if n_rep == 1:
532
  return hidden_states
533
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
534
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
535
 
536
 
@@ -562,59 +677,96 @@ class DeepseekV3Attention(nn.Module):
562
 
563
  self.is_causal = True
564
 
 
565
  if self.q_lora_rank is None:
566
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
 
 
567
  else:
568
- self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
 
 
569
  self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
570
- self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
 
 
571
 
572
- self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size,
573
- config.kv_lora_rank + config.qk_rope_head_dim,
574
- bias=config.attention_bias)
575
  self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
576
- self.kv_b_proj = nn.Linear(config.kv_lora_rank,
577
- self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
578
- bias=False)
 
 
579
 
 
580
  self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias)
581
 
 
582
  self._init_rope()
583
 
 
584
  self.softmax_scale = self.q_head_dim ** (-0.5)
585
  if self.config.rope_scaling is not None:
 
586
  mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
587
  scaling_factor = self.config.rope_scaling["factor"]
588
  if mscale_all_dim:
 
589
  self.softmax_scale *= yarn_get_mscale(scaling_factor, mscale_all_dim) ** 2
590
 
591
  def _init_rope(self):
 
 
 
592
  if self.config.rope_scaling is None:
593
- self.rotary_emb = DeepseekV3RotaryEmbedding(self.qk_rope_head_dim,
594
- max_position_embeddings=self.max_position_embeddings,
595
- base=self.rope_theta)
 
 
596
  else:
597
  scaling_type = self.config.rope_scaling["type"]
598
  scaling_factor = self.config.rope_scaling["factor"]
 
599
  if scaling_type == "linear":
600
- self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(self.qk_rope_head_dim,
601
- max_position_embeddings=self.max_position_embeddings,
602
- scaling_factor=scaling_factor,
603
- base=self.rope_theta)
 
 
604
  elif scaling_type == "dynamic":
605
- self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(self.qk_rope_head_dim,
606
- max_position_embeddings=self.max_position_embeddings,
607
- scaling_factor=scaling_factor,
608
- base=self.rope_theta)
 
 
609
  elif scaling_type == "yarn":
610
- kwargs = {key: self.config.rope_scaling[key]
611
- for key in ["original_max_position_embeddings", "beta_fast", "beta_slow", "mscale", "mscale_all_dim"]
612
- if key in self.config.rope_scaling}
613
- self.rotary_emb = DeepseekV3YarnRotaryEmbedding(self.qk_rope_head_dim,
614
- max_position_embeddings=self.max_position_embeddings,
615
- scaling_factor=scaling_factor,
616
- base=self.rope_theta,
617
- **kwargs)
618
  else:
619
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
620
 
@@ -628,10 +780,17 @@ class DeepseekV3Attention(nn.Module):
628
  use_cache: bool = False,
629
  **kwargs,
630
  ):
 
 
 
631
  if "padding_mask" in kwargs:
632
- warnings.warn("Passing `padding_mask` is deprecated. Use `attention_mask` instead.")
633
  bsz, q_len, _ = hidden_states.size()
634
 
 
635
  if self.q_lora_rank is None:
636
  q = self.q_proj(hidden_states)
637
  else:
@@ -639,51 +798,75 @@ class DeepseekV3Attention(nn.Module):
639
  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
640
  q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
641
 
 
642
  compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
643
- compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
 
 
644
  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
645
- kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
646
- .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
647
- .transpose(1, 2))
648
- k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
649
  kv_seq_len = value_states.shape[-2]
650
  if past_key_value is not None:
651
  if self.layer_idx is None:
652
- raise ValueError(f"Missing `layer_idx` for caching in {self.__class__.__name__}.")
 
 
653
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
654
 
655
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
 
656
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
657
 
658
  query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
659
- query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
660
  query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
661
 
662
  key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
663
- key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
664
  key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
665
 
666
  if past_key_value is not None:
667
- cache_kwargs = {"sin": sin, "cos": cos}
668
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
669
 
670
- attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale)
671
  if attention_mask is not None:
672
  attn_weights = attn_weights + attention_mask
673
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
674
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
675
  attn_output = torch.matmul(attn_weights, value_states)
676
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.num_heads * self.v_head_dim)
677
  attn_output = self.o_proj(attn_output)
678
 
679
  if not output_attentions:
680
  attn_weights = None
 
681
  return attn_output, attn_weights, past_key_value
682
 
683
 
684
  class DeepseekV3FlashAttention2(DeepseekV3Attention):
685
  """
686
- DeepseekV3 flash attention module using flash_attn APIs.
 
687
  """
688
  def __init__(self, *args, **kwargs):
689
  super().__init__(*args, **kwargs)
@@ -699,11 +882,14 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
699
  use_cache: bool = False,
700
  **kwargs,
701
  ):
 
702
  if "padding_mask" in kwargs:
703
- warnings.warn("Passing `padding_mask` is deprecated. Use `attention_mask` instead.")
 
 
704
  attention_mask = kwargs.pop("padding_mask")
705
 
706
- output_attentions = False # flash_attn2 does not expose attention probs
707
 
708
  bsz, q_len, _ = hidden_states.shape
709
  if self.q_lora_rank is None:
@@ -714,45 +900,64 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
714
  q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
715
 
716
  compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
717
- compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
 
 
718
  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
719
- kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
720
- .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
721
- .transpose(1, 2))
722
- k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
723
  kv_seq_len = value_states.shape[-2]
724
  if past_key_value is not None:
725
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
726
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
727
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
728
 
729
  query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
730
- query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
731
  query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
732
 
733
  key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
734
- key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
735
  key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
736
 
737
  if self.q_head_dim != self.v_head_dim:
 
738
  value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
739
 
740
  if past_key_value is not None:
741
- cache_kwargs = {"sin": sin, "cos": sin}
742
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
743
 
 
744
  query_states = query_states.transpose(1, 2)
745
  key_states = key_states.transpose(1, 2)
746
  value_states = value_states.transpose(1, 2)
747
 
748
  dropout_rate = self.attention_dropout if self.training else 0.0
 
 
749
  input_dtype = query_states.dtype
750
  if input_dtype == torch.float32:
751
- target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype
752
  query_states = query_states.to(target_dtype)
753
  key_states = key_states.to(target_dtype)
754
  value_states = value_states.to(target_dtype)
755
 
 
756
  attn_output = self._flash_attention_forward(
757
  query_states,
758
  key_states,
@@ -764,10 +969,12 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
764
  )
765
 
766
  if self.q_head_dim != self.v_head_dim:
767
- attn_output = attn_output[:, :, :, :self.v_head_dim]
768
 
 
769
  attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
770
  attn_output = self.o_proj(attn_output)
 
771
  return attn_output, None, past_key_value
772
 
773
  def _flash_attention_forward(
@@ -780,9 +987,13 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
780
  dropout: float = 0.0,
781
  softmax_scale: Optional[float] = None,
782
  ) -> torch.Tensor:
 
 
 
783
  if not self._flash_attn_uses_top_left_mask:
784
  causal = self.is_causal
785
  else:
 
786
  causal = self.is_causal and query_length != 1
787
 
788
  if attention_mask is not None:
@@ -792,7 +1003,9 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
792
  value_states,
793
  indices_q,
794
  (cu_seqlens_q, cu_seqlens_k),
795
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k)) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
 
 
796
  attn_output_unpad = flash_attn_varlen_func(
797
  query_states,
798
  key_states,
@@ -805,7 +1018,9 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
805
  softmax_scale=softmax_scale,
806
  causal=causal,
807
  )
808
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
809
  else:
810
  attn_output = flash_attn_func(
811
  query_states,
@@ -815,6 +1030,7 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
815
  softmax_scale=softmax_scale,
816
  causal=causal,
817
  )
 
818
  return attn_output
819
 
820
  def _upad_input(
@@ -825,29 +1041,53 @@ class DeepseekV3FlashAttention2(DeepseekV3Attention):
825
  attention_mask: torch.Tensor,
826
  query_length: int,
827
  ):
 
 
 
828
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
829
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
830
 
831
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
832
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
833
  if query_length == kv_seq_len:
834
- query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
835
  cu_seqlens_q = cu_seqlens_k
836
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
837
  indices_q = indices_k
838
  elif query_length == 1:
839
  max_seqlen_in_batch_q = 1
840
- cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
 
 
841
  indices_q = cu_seqlens_q[:-1]
842
  query_layer = query_layer.squeeze(1)
843
  else:
 
844
  attention_mask = attention_mask[:, -query_length:]
845
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
846
- return (query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
847
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k))
848
 
849
 
850
- # Attach attention classes in a dictionary for easy selection.
851
  ATTENTION_CLASSES = {
852
  "eager": DeepseekV3Attention,
853
  "flash_attention_2": DeepseekV3FlashAttention2,
@@ -865,15 +1105,27 @@ class DeepseekV3DecoderLayer(nn.Module):
865
  def __init__(self, config: DeepseekV3Config, layer_idx: int):
866
  super().__init__()
867
  self.hidden_size = config.hidden_size
868
- self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
869
- if (config.n_routed_experts is not None and
870
- layer_idx >= config.first_k_dense_replace and
871
- layer_idx % config.moe_layer_freq == 0):
  self.mlp = DeepseekV3MoE(config)
873
  else:
874
  self.mlp = DeepseekV3MLP(config)
875
- self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
876
- self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
877
 
878
  def forward(
879
  self,
@@ -886,10 +1138,14 @@ class DeepseekV3DecoderLayer(nn.Module):
886
  **kwargs
887
  ):
888
  """
889
- Forward pass for one Deepseek decoder layer.
890
  """
891
  residual = hidden_states
 
 
892
  hidden_states = self.input_layernorm(hidden_states)
 
 
893
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
894
  hidden_states=hidden_states,
895
  attention_mask=attention_mask,
@@ -901,21 +1157,21 @@ class DeepseekV3DecoderLayer(nn.Module):
901
  )
902
  hidden_states = residual + hidden_states
903
 
 
904
  residual = hidden_states
905
  hidden_states = self.post_attention_layernorm(hidden_states)
906
- # Dynamic Token Dropping
907
- importance = torch.sigmoid(nn.Linear(self.hidden_size, 1).to(hidden_states.device)(hidden_states))
908
- mask = (importance > (0.2 if self.training else 0.5)).float()
909
- hidden_states = hidden_states * mask + (1 - mask) * hidden_states.detach()
910
 
 
911
  hidden_states = self.mlp(hidden_states)
912
  hidden_states = residual + hidden_states
913
 
914
  outputs = (hidden_states,)
915
  if output_attentions:
916
  outputs += (self_attn_weights,)
 
917
  if use_cache:
918
  outputs += (present_key_value,)
 
919
  return outputs
920
 
921
 
@@ -925,7 +1181,7 @@ class DeepseekV3DecoderLayer(nn.Module):
925
 
926
  DeepseekV3_START_DOCSTRING = r"""
927
  This model inherits from `PreTrainedModel`. Check the superclass documentation
928
- for the generic methods the library implements for all its models (such as loading or saving, etc.)
929
  """
930
 
931
  class DeepseekV3PreTrainedModel(PreTrainedModel):
@@ -938,6 +1194,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
938
  _supports_cache_class = True
939
 
940
  def _init_weights(self, module):
 
941
  std = self.config.initializer_range
942
  if isinstance(module, nn.Linear):
943
  module.weight.data.normal_(mean=0.0, std=std)
@@ -960,32 +1217,41 @@ DeepseekV3_INPUTS_DOCSTRING = r"""
960
  input_ids (torch.LongTensor): shape `(batch_size, sequence_length)`
961
  attention_mask (torch.Tensor): shape `(batch_size, sequence_length)` or `(batch_size, 1, seq_len, seq_len)`, optional.
962
  position_ids (torch.LongTensor): shape `(batch_size, sequence_length)`, optional.
963
- past_key_values (Cache or tuple(tuple(torch.FloatTensor))): optional pre-computed key/value hidden-states.
 
964
  inputs_embeds (torch.FloatTensor): shape `(batch_size, sequence_length, hidden_size)`, optional.
965
- use_cache (bool), optional.
966
- output_attentions (bool), optional.
967
- output_hidden_states (bool), optional.
968
- return_dict (bool), optional.
969
  """
970
 
971
- @add_start_docstrings("The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", DeepseekV3_START_DOCSTRING)
972
  class DeepseekV3Model(DeepseekV3PreTrainedModel):
973
  """
974
- Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a DeepseekV3DecoderLayer.
975
  """
976
  def __init__(self, config: DeepseekV3Config):
977
  super().__init__(config)
978
  self.padding_idx = config.pad_token_id
979
  self.vocab_size = config.vocab_size
 
980
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
981
- self.layers = nn.ModuleList([DeepseekV3DecoderLayer(config, layer_idx)
982
- for layer_idx in range(config.num_hidden_layers)])
983
- self._use_flash_attention_2 = (config._attn_implementation == "flash_attention_2")
984
  self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
985
  self.gradient_checkpointing = False
986
  self.post_init()
987
- # Enable Torch 2.x compile for the forward pass.
988
- self.forward = torch.compile(self.forward, dynamic=True)
989
 
990
  def get_input_embeddings(self) -> nn.Embedding:
991
  return self.embed_tokens
@@ -1006,8 +1272,10 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
1006
  output_hidden_states: Optional[bool] = None,
1007
  return_dict: Optional[bool] = None,
1008
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
1009
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1010
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1011
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1012
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1013
 
@@ -1029,17 +1297,29 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
1029
  past_key_values_length = past_key_values.get_usable_length(seq_length)
1030
 
1031
  if position_ids is None:
1032
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1033
- position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
1034
  position_ids = position_ids.unsqueeze(0)
1035
 
1036
  if inputs_embeds is None:
1037
  inputs_embeds = self.embed_tokens(input_ids)
1038
 
 
1039
  if self._use_flash_attention_2:
1040
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1041
  else:
1042
- attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)
1043
 
1044
  hidden_states = inputs_embeds
1045
 
@@ -1050,33 +1330,57 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
1050
  for idx, decoder_layer in enumerate(self.layers):
1051
  if output_hidden_states:
1052
  all_hidden_states += (hidden_states,)
 
 
1053
  if self.gradient_checkpointing and self.training:
1054
  def create_custom_forward(module):
1055
  def custom_forward(*inputs):
1056
  return module(*inputs, output_attentions=output_attentions, use_cache=use_cache)
1057
  return custom_forward
1058
- layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(decoder_layer),
1059
- hidden_states, attention_mask, position_ids, past_key_values)
1060
  else:
1061
- layer_outputs = decoder_layer(hidden_states,
1062
- attention_mask=attention_mask,
1063
- position_ids=position_ids,
1064
- past_key_value=past_key_values,
1065
- output_attentions=output_attentions,
1066
- use_cache=use_cache)
1067
  hidden_states = layer_outputs[0]
1068
  if use_cache:
1069
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
1070
  if output_attentions:
1071
  all_self_attns += (layer_outputs[1],)
 
1072
  hidden_states = self.norm(hidden_states)
1073
  if output_hidden_states:
1074
  all_hidden_states += (hidden_states,)
1075
 
1076
- next_cache = next_decoder_cache.to_legacy_cache() if (use_cache and use_legacy_cache) else next_decoder_cache
1077
 
1078
  if not return_dict:
1079
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1080
  return BaseModelOutputWithPast(
1081
  last_hidden_state=hidden_states,
1082
  past_key_values=next_cache,
@@ -1097,7 +1401,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1097
  self.model = DeepseekV3Model(config)
1098
  self.vocab_size = config.vocab_size
1099
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1100
- # The Knowledge Distillation head can be added here if needed.
1101
  self.post_init()
1102
 
1103
  def get_input_embeddings(self) -> nn.Embedding:
@@ -1119,7 +1423,9 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1119
  return self.model
1120
 
1121
  @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1122
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1123
  def forward(
1124
  self,
1125
  input_ids: Optional[torch.LongTensor] = None,
@@ -1133,39 +1439,43 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1133
  output_hidden_states: Optional[bool] = None,
1134
  return_dict: Optional[bool] = None,
1135
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1136
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1137
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1138
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1139
 
1140
- # FP8 mixed precision support if configured.
1141
- if self.config.mixed_precision == "fp8":
1142
- autocast_context = torch.autocast(device_type='cuda', dtype=torch.float8)
1143
- else:
1144
- autocast_context = torch.no_grad() # no-op context
1145
-
1146
- with autocast_context:
1147
- outputs = self.model(
1148
- input_ids=input_ids,
1149
- attention_mask=attention_mask,
1150
- position_ids=position_ids,
1151
- past_key_values=past_key_values,
1152
- inputs_embeds=inputs_embeds,
1153
- use_cache=use_cache,
1154
- output_attentions=output_attentions,
1155
- output_hidden_states=output_hidden_states,
1156
- return_dict=return_dict,
1157
- )
1158
  hidden_states = outputs[0]
1159
- logits = self.lm_head(hidden_states).float()
1160
- # Optionally, if you want to compute distillation logits:
1161
- # distill_logits = self.distill_head(hidden_states)
1162
  loss = None
1163
  if labels is not None:
 
1164
  shift_logits = logits[..., :-1, :].contiguous()
1165
  shift_labels = labels[..., 1:].contiguous()
 
1166
  loss_fct = CrossEntropyLoss()
1167
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1168
- shift_labels = shift_labels.view(-1).to(shift_logits.device)
 
1169
  loss = loss_fct(shift_logits, shift_labels)
1170
 
1171
  if not return_dict:
@@ -1188,6 +1498,9 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1188
  inputs_embeds: Optional[torch.FloatTensor] = None,
1189
  **kwargs
1190
  ):
1191
  if past_key_values is not None:
1192
  if isinstance(past_key_values, Cache):
1193
  cache_length = past_key_values.get_seq_length()
@@ -1196,37 +1509,50 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1196
  else:
1197
  cache_length = past_length = past_key_values[0][0].shape[2]
1198
  max_cache_length = None
 
1199
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1200
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
 
1201
  elif past_length < input_ids.shape[1]:
1202
  input_ids = input_ids[:, past_length:]
 
1203
  if max_cache_length is not None and attention_mask is not None:
1204
  if cache_length + input_ids.shape[1] > max_cache_length:
1205
  attention_mask = attention_mask[:, -max_cache_length:]
 
1206
  position_ids = kwargs.get("position_ids", None)
1207
  if attention_mask is not None and position_ids is None:
1208
  position_ids = attention_mask.long().cumsum(-1) - 1
1209
  position_ids.masked_fill_(attention_mask == 0, 1)
1210
  if past_key_values:
1211
- position_ids = position_ids[:, -input_ids.shape[1]:]
 
 
1212
  if inputs_embeds is not None and past_key_values is None:
1213
  model_inputs = {"inputs_embeds": inputs_embeds}
1214
  else:
1215
  model_inputs = {"input_ids": input_ids}
1216
- model_inputs.update({
1217
- "position_ids": position_ids,
1218
- "past_key_values": past_key_values,
1219
- "use_cache": kwargs.get("use_cache"),
1220
- "attention_mask": attention_mask,
1221
- })
1222
  return model_inputs
1223
 
1224
  @staticmethod
1225
  def _reorder_cache(past_key_values: Tuple, beam_idx: torch.Tensor) -> Tuple:
1226
  reordered_past = ()
1227
  for layer_past in past_key_values:
1228
- reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device))
1229
- for past_state in layer_past),)
1230
  return reordered_past
1231
 
1232
 
@@ -1247,6 +1573,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
1247
  self.num_labels = config.num_labels
1248
  self.model = DeepseekV3Model(config)
1249
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
1250
  self.post_init()
1251
 
1252
  def get_input_embeddings(self) -> nn.Embedding:
@@ -1269,6 +1596,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
1269
  output_hidden_states: Optional[bool] = None,
1270
  return_dict: Optional[bool] = None,
1271
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
 
1272
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1273
  transformer_outputs = self.model(
1274
  input_ids,
@@ -1289,17 +1617,24 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
1289
  else:
1290
  batch_size = inputs_embeds.shape[0]
1291
 
 
1292
  if self.config.pad_token_id is None and batch_size != 1:
1293
- raise ValueError("Cannot handle batch sizes > 1 if no pad token is defined.")
 
 
1294
  if self.config.pad_token_id is None:
1295
  sequence_lengths = -1
1296
  else:
1297
  if input_ids is not None:
1298
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(logits.device)
 
 
1299
  else:
1300
  sequence_lengths = -1
1301
 
1302
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1303
 
1304
  loss = None
1305
  if labels is not None:
@@ -1311,12 +1646,18 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
1311
  self.config.problem_type = "single_label_classification"
1312
  else:
1313
  self.config.problem_type = "multi_label_classification"
 
1314
  if self.config.problem_type == "regression":
1315
  loss_fct = MSELoss()
1316
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else loss_fct(pooled_logits, labels)
1317
  elif self.config.problem_type == "single_label_classification":
1318
  loss_fct = CrossEntropyLoss()
1319
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1320
  elif self.config.problem_type == "multi_label_classification":
1321
  loss_fct = BCEWithLogitsLoss()
1322
  loss = loss_fct(pooled_logits, labels)
 
54
 
55
  # If flash-attn is available
56
  if is_flash_attn_2_available():
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
59
 
60
  # This helps make `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
61
  if is_torch_fx_available():
62
  if not is_torch_greater_or_equal_than_1_13:
63
  import torch.fx
64
+
65
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
66
 
67
  _CONFIG_FOR_DOC = "DeepseekV3Config"
68
 
69
+
70
  # ==============================================================================
71
  # Rotary Embedding Helpers
72
  # ==============================================================================
 
82
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
83
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
84
  max_seqlen_in_batch = seqlens_in_batch.max().item()
85
+ # Build prefix sums
86
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
87
  return indices, cu_seqlens, max_seqlen_in_batch
88
 
89
+
90
  # ==============================================================================
91
  # Normalization Layers
92
  # ==============================================================================
 
102
  self.variance_epsilon = eps
103
 
104
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
105
+ # IMPROVEMENT: Provide type-safety & potential in-place usage
106
  input_dtype = hidden_states.dtype
107
  variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
108
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
109
  return (self.weight * hidden_states).to(input_dtype)
110
 
111
+
112
  ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
113
 
114
+
115
  # ==============================================================================
116
  # Rotary Embeddings
117
  # ==============================================================================
 
132
  self.max_position_embeddings = max_position_embeddings
133
  self.base = base
134
 
135
+ inv_freq = 1.0 / (
136
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
137
+ )
138
  self.register_buffer("inv_freq", inv_freq, persistent=False)
139
+
140
+ # Build here to make `torch.jit.trace` work.
141
  self._set_cos_sin_cache(
142
  seq_len=max_position_embeddings,
143
  device=self.inv_freq.device,
144
  dtype=torch.get_default_dtype(),
145
  )
146
+ self.max_seq_len_cached = None
147
 
148
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
149
  self.max_seq_len_cached = seq_len
150
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
151
+
152
+ freqs = torch.outer(t, self.inv_freq.to(t.device))
153
+ # Different from paper, but uses a different permutation to achieve the same effect
154
  emb = torch.cat((freqs, freqs), dim=-1)
155
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
156
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
159
  """
160
  x: [batch_size, num_heads, seq_len, head_size]
161
  """
162
+ if (self.max_seq_len_cached is None) or (seq_len and seq_len > self.max_seq_len_cached):
163
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
164
+
165
+ return (self.cos_cached[:seq_len].to(dtype=x.dtype),
166
+ self.sin_cached[:seq_len].to(dtype=x.dtype))
 
167
 
168
 
169
  class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
170
  """
171
+ RoPE extended with linear scaling. Credits to the Reddit user /u/kaiokendev
172
  """
173
  def __init__(
174
  self,
 
183
 
184
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
185
  self.max_seq_len_cached = seq_len
186
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
187
+ t = t / self.scaling_factor
188
  freqs = torch.outer(t, self.inv_freq)
189
  emb = torch.cat((freqs, freqs), dim=-1)
190
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
 
194
  class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
195
  """
196
  RoPE extended with Dynamic NTK scaling.
197
+ Credits to the Reddit users /u/bloc97 and /u/emozilla
198
  """
199
  def __init__(
200
  self,
 
209
 
210
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
211
  self.max_seq_len_cached = seq_len
212
+
213
  if seq_len > self.max_position_embeddings:
214
  base = self.base * (
215
  (self.scaling_factor * seq_len / self.max_position_embeddings)
216
  - (self.scaling_factor - 1)
217
  ) ** (self.dim / (self.dim - 2))
218
+ inv_freq = 1.0 / (
219
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
220
+ )
221
  self.register_buffer("inv_freq", inv_freq, persistent=False)
222
+
223
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
224
  freqs = torch.outer(t, self.inv_freq)
225
  emb = torch.cat((freqs, freqs), dim=-1)
226
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
227
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
228
 
229
 
230
+ # Extra Yarn-based formulas, as in your original code
231
  def yarn_find_correction_dim(
232
  num_rotations: float,
233
  dim: int,
234
  base: int = 10000,
235
  max_position_embeddings: int = 2048
236
  ):
237
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
238
+ 2 * math.log(base)
239
+ )
240
 
241
 
242
  def yarn_find_correction_range(
 
246
  base: int = 10000,
247
  max_position_embeddings: int = 2048
248
  ):
249
+ low = math.floor(
250
+ yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
251
+ )
252
+ high = math.ceil(
253
+ yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
254
+ )
255
+ # Clamped range
256
  return max(low, 0), min(high, dim - 1)
257
 
258
 
 
298
  def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
299
  self.max_seq_len_cached = seq_len
300
  dim = self.dim
301
+
302
+ freq_extra = 1.0 / (
303
+ self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
304
+ )
305
+ freq_inter = 1.0 / (
306
+ self.scaling_factor
307
+ * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
308
+ )
309
+
310
+ low, high = yarn_find_correction_range(
311
+ self.beta_fast,
312
+ self.beta_slow,
313
+ dim,
314
+ self.base,
315
+ self.original_max_position_embeddings,
316
+ )
317
  inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
318
  inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
319
  self.register_buffer("inv_freq", inv_freq, persistent=False)
320
+
321
  t = torch.arange(seq_len, device=device, dtype=torch.float32)
322
  freqs = torch.outer(t, inv_freq)
323
+ _mscale = float(
324
+ yarn_get_mscale(self.scaling_factor, self.mscale)
325
+ / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
326
+ )
327
+
328
  emb = torch.cat((freqs, freqs), dim=-1)
329
  self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
330
  self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
331
 
332
 
333
+ # ==============================================================================
334
  # General Rotary helper functions
335
  # ==============================================================================
336
 
 
380
  super().__init__()
381
  self.config = config
382
  self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
383
+ self.intermediate_size = (
384
+ config.intermediate_size if intermediate_size is None else intermediate_size
385
+ )
386
 
387
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
388
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 
413
  self.norm_topk_prob = config.norm_topk_prob
414
  self.gating_dim = config.hidden_size
415
 
416
+ # Gating weight
417
  self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
418
+
419
  if self.topk_method == "noaux_tc":
420
+ self.e_score_correction_bias = nn.Parameter(
421
+ torch.empty((self.n_routed_experts))
422
+ )
423
+
424
  self.reset_parameters()
425
 
426
  def reset_parameters(self):
 
433
  Compute gating scores and select top-k experts.
434
  """
435
  bsz, seq_len, h = hidden_states.shape
436
+
437
+ # 1) Compute gating scores
438
  logits = F.linear(hidden_states.float(), self.weight.float(), None)
439
  if self.scoring_func == "sigmoid":
440
  scores = logits.sigmoid()
441
  else:
442
+ raise NotImplementedError(
443
+ f"Unsupported gating scoring function: {self.scoring_func}"
444
+ )
445
 
446
+ # 2) TopK selection
447
  if self.topk_method == "noaux_tc":
448
+ # This is a specialized approach from your original code
449
+ # IMPROVEMENT: Could consider generalizing to top2 gating or other advanced techniques
450
  scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
451
+ group_scores = (
452
+ scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
453
+ )
454
+ group_idx = torch.topk(
455
+ group_scores, k=self.topk_group, dim=-1, sorted=False
456
+ )[1] # [n, top_k_group]
457
  group_mask = torch.zeros_like(group_scores)
458
  group_mask.scatter_(1, group_idx, 1)
459
+ score_mask = group_mask.unsqueeze(-1).expand(
460
+ bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
461
+ ).reshape(bsz * seq_len, -1)
462
  tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
463
  _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
464
  topk_weight = scores_for_choice.gather(1, topk_idx)
465
  else:
466
+ raise NotImplementedError(
467
+ f"Unsupported topk_method: {self.topk_method}"
468
+ )
469
 
470
+ # 3) Norm gate to sum to 1 if top_k > 1
471
  if self.top_k > 1 and self.norm_topk_prob:
472
  denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
473
  topk_weight = topk_weight / denominator
474
 
475
+ # 4) Multiply scaling factor
476
  topk_weight = topk_weight * self.routed_scaling_factor
477
 
478
  return topk_idx, topk_weight
 
495
  self.experts_per_rank = config.n_routed_experts // config.ep_size
496
  self.ep_rank = dist.get_rank()
497
 
498
+ # Build experts
499
  experts_list = []
500
  for i in range(config.n_routed_experts):
501
+ # only build if belongs to current rank
502
  if self.ep_size > 1:
503
  if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank:
504
+ experts_list.append(
505
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
506
+ )
507
  else:
508
  experts_list.append(None)
509
  else:
510
+ experts_list.append(
511
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
512
+ )
513
  self.experts = nn.ModuleList(experts_list)
514
 
515
+ # Gate
516
  self.gate = MoEGate(config)
517
 
518
+ # Optionally shared experts
519
  if config.n_shared_experts is not None:
520
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
521
+ self.shared_experts = DeepseekV3MLP(
522
+ config=config, intermediate_size=intermediate_size
523
+ )
524
  else:
525
  self.shared_experts = None
526
 
527
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
528
  identity = hidden_states
529
  orig_shape = hidden_states.shape
530
+
531
  topk_idx, topk_weight = self.gate(hidden_states)
532
+ # Flatten
533
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
534
+
535
+ # Inference
536
  if not self.training:
537
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
538
  else:
539
+ # For training, you’d typically do a distributed MoE approach
540
+ # or a specialized approach from your original code.
541
+ # This placeholder just calls `moe_infer` for demonstration.
542
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
543
+
544
+ # Add shared experts if present
545
  if self.shared_experts is not None:
546
  y = y + self.shared_experts(identity)
547
+
548
  return y
549
 
550
  @torch.no_grad()
551
  def moe_infer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
552
  """
553
+ MoE inference path for each token. This code can be parallelized or distributed for better performance.
 
554
  """
555
  cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
556
  cnts.scatter_(1, topk_ids, 1)
 
559
  sorted_tokens = x[idxs // topk_ids.shape[1]]
560
  sorted_tokens_shape = sorted_tokens.shape
561
 
562
+ # Handle distribution if ep_size>1
563
  if self.ep_size > 1:
564
  tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
565
  tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
566
  dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
567
+ output_splits = (
568
+ tokens_per_expert_group.view(self.ep_size, self.experts_per_rank)
569
+ .sum(1)
570
+ .cpu()
571
+ .numpy()
572
+ .tolist()
573
+ )
574
+ gathered_tokens = sorted_tokens.new_empty(
575
+ tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
576
+ )
577
  input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
578
+ dist.all_to_all(
579
+ list(gathered_tokens.split(input_split_sizes)),
580
+ list(sorted_tokens.split(input_split_sizes)),
581
+ )
582
+ tokens_per_expert_post_gather = tokens_per_expert_group.view(
583
+ self.ep_size, self.experts_per_rank
584
+ ).sum(dim=0)
585
+ gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
586
  s = 0
587
  for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
588
  gatherd_idxs[s : s + k] = i % self.experts_per_rank
 
592
  tokens_per_expert = tokens_per_expert_post_gather
593
 
594
  tokens_per_expert = tokens_per_expert.cpu().numpy()
595
+
596
  outputs = []
597
  start_idx = 0
598
+ # Forward pass for each expert’s assigned tokens
599
  for i, num_tokens in enumerate(tokens_per_expert):
600
  end_idx = start_idx + num_tokens
601
  if num_tokens == 0:
 
605
  expert_out = expert(tokens_for_this_expert) if expert else tokens_for_this_expert
606
  outputs.append(expert_out)
607
  start_idx = end_idx
608
+
609
+ outs = (
610
+ torch.cat(outputs, dim=0)
611
+ if len(outputs)
612
+ else sorted_tokens.new_empty(0)
613
+ )
614
+
615
  if self.ep_size > 1:
616
  new_x = torch.empty_like(outs)
617
  new_x[gatherd_idxs] = outs
618
  gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
619
+ dist.all_to_all(
620
+ list(gathered_tokens.split(input_split_sizes)),
621
+ list(new_x.split(output_splits)),
622
+ )
623
  outs = gathered_tokens
624
+
625
  new_x = torch.empty_like(outs)
626
  new_x[idxs] = outs
627
+ final_out = (
628
+ new_x.view(*topk_ids.shape, -1)
629
+ .type(topk_weight.dtype)
630
+ .mul_(topk_weight.unsqueeze(dim=-1))
631
+ .sum(dim=1)
632
+ .type(new_x.dtype)
633
+ )
634
  return final_out
635
 
636
 
 
643
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
644
  if n_rep == 1:
645
  return hidden_states
646
+ hidden_states = hidden_states[:, :, None, :, :].expand(
647
+ batch, num_key_value_heads, n_rep, slen, head_dim
648
+ )
649
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
650
 
651
 
 
677
 
678
  self.is_causal = True
679
 
680
+ # Q-proj
681
  if self.q_lora_rank is None:
682
+ self.q_proj = nn.Linear(
683
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False
684
+ )
685
  else:
686
+ self.q_a_proj = nn.Linear(
687
+ self.hidden_size, config.q_lora_rank, bias=config.attention_bias
688
+ )
689
  self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
690
+ self.q_b_proj = nn.Linear(
691
+ config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
692
+ )
693
 
694
+ # K,V-proj (MQA style)
695
+ self.kv_a_proj_with_mqa = nn.Linear(
696
+ self.hidden_size,
697
+ config.kv_lora_rank + config.qk_rope_head_dim,
698
+ bias=config.attention_bias,
699
+ )
700
  self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
701
+ self.kv_b_proj = nn.Linear(
702
+ config.kv_lora_rank,
703
+ self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
704
+ bias=False,
705
+ )
706
 
707
+ # Out proj
708
  self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias)
709
 
710
+ # Build the rotary embedding
711
  self._init_rope()
712
 
713
+ # IMPROVEMENT: Custom softmax scaling, adapted for YaRN scaling when configured
714
  self.softmax_scale = self.q_head_dim ** (-0.5)
715
  if self.config.rope_scaling is not None:
716
+ # E.g. yarn-based scaling can factor in additional multipliers
717
  mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
718
  scaling_factor = self.config.rope_scaling["factor"]
719
  if mscale_all_dim:
720
+ # Simple example using the Yarn approach
721
  self.softmax_scale *= yarn_get_mscale(scaling_factor, mscale_all_dim) ** 2
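+ # With YaRN, the effective scale becomes q_head_dim**-0.5 * mscale**2, where
+ # mscale grows roughly logarithmically with the context-extension factor.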
722
 
723
  def _init_rope(self):
724
+ """
725
+ Initializes RoPE depending on scaling type: linear, dynamic, yarn, etc.
726
+ """
727
  if self.config.rope_scaling is None:
728
+ self.rotary_emb = DeepseekV3RotaryEmbedding(
729
+ self.qk_rope_head_dim,
730
+ max_position_embeddings=self.max_position_embeddings,
731
+ base=self.rope_theta,
732
+ )
733
  else:
734
  scaling_type = self.config.rope_scaling["type"]
735
  scaling_factor = self.config.rope_scaling["factor"]
736
+
737
  if scaling_type == "linear":
738
+ self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
739
+ self.qk_rope_head_dim,
740
+ max_position_embeddings=self.max_position_embeddings,
741
+ scaling_factor=scaling_factor,
742
+ base=self.rope_theta,
743
+ )
744
  elif scaling_type == "dynamic":
745
+ self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
746
+ self.qk_rope_head_dim,
747
+ max_position_embeddings=self.max_position_embeddings,
748
+ scaling_factor=scaling_factor,
749
+ base=self.rope_theta,
750
+ )
751
  elif scaling_type == "yarn":
752
+ kwargs = {
753
+ key: self.config.rope_scaling[key]
754
+ for key in [
755
+ "original_max_position_embeddings",
756
+ "beta_fast",
757
+ "beta_slow",
758
+ "mscale",
759
+ "mscale_all_dim",
760
+ ]
761
+ if key in self.config.rope_scaling
762
+ }
763
+ self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
764
+ self.qk_rope_head_dim,
765
+ max_position_embeddings=self.max_position_embeddings,
766
+ scaling_factor=scaling_factor,
767
+ base=self.rope_theta,
768
+ **kwargs,
769
+ )
770
  else:
771
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
772
 
 
780
  use_cache: bool = False,
781
  **kwargs,
782
  ):
783
+ """
784
+ Standard forward pass for multi-headed self-attention.
785
+ """
786
  if "padding_mask" in kwargs:
787
+ warnings.warn(
788
+ "Passing `padding_mask` is deprecated. Use `attention_mask` instead."
789
+ )
790
+
791
  bsz, q_len, _ = hidden_states.size()
792
 
793
+ # Q projection
794
  if self.q_lora_rank is None:
795
  q = self.q_proj(hidden_states)
796
  else:
 
798
  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
799
  q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
800
 
801
+ # MQA: K,V from single projection
802
  compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
803
+ compressed_kv, k_pe = torch.split(
804
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
805
+ )
806
  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
807
+ kv = (
808
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
809
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
810
+ .transpose(1, 2)
811
+ )
812
+ k_nope, value_states = torch.split(
813
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
814
+ )
815
  kv_seq_len = value_states.shape[-2]
816
  if past_key_value is not None:
817
  if self.layer_idx is None:
818
+ raise ValueError(
819
+ f"Missing `layer_idx` for caching. Provide layer_idx in {self.__class__.__name__}."
820
+ )
821
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
822
 
823
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
824
+
825
+ # Apply rotary to query and key
826
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
827
 
828
  query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
829
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
830
  query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
831
 
832
  key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
833
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
834
  key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
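+ # Q and K are each the concatenation of a non-rotary ("nope") part and a
+ # rotary ("pe") part, so q_head_dim = qk_nope_head_dim + qk_rope_head_dim.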
835
 
836
  if past_key_value is not None:
837
+ cache_kwargs = {"sin": sin, "cos": cos} # for RoPE
838
+ key_states, value_states = past_key_value.update(
839
+ key_states, value_states, self.layer_idx, cache_kwargs
840
+ )
841
+
842
+ attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3))
843
+ * self.softmax_scale)
844
 
 
845
  if attention_mask is not None:
846
  attn_weights = attn_weights + attention_mask
847
+
848
+ # Use float32 for more stable softmax
849
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
850
+ attn_weights = nn.functional.dropout(
851
+ attn_weights, p=self.attention_dropout, training=self.training
852
+ )
853
  attn_output = torch.matmul(attn_weights, value_states)
854
+
855
+ attn_output = attn_output.transpose(1, 2).contiguous()
856
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
857
+
858
  attn_output = self.o_proj(attn_output)
859
 
860
  if not output_attentions:
861
  attn_weights = None
862
+
863
  return attn_output, attn_weights, past_key_value
864
 
865
 
866
  class DeepseekV3FlashAttention2(DeepseekV3Attention):
867
  """
868
+ DeepseekV3 flash attention module. Inherits the same Q/K/V projections from DeepseekV3Attention.
869
+ Only the forward pass changes to use flash_attn APIs.
870
  """
871
  def __init__(self, *args, **kwargs):
872
  super().__init__(*args, **kwargs)
 
882
  use_cache: bool = False,
883
  **kwargs,
884
  ):
885
+ # Overridden forward logic using flash attention
886
  if "padding_mask" in kwargs:
887
+ warnings.warn(
888
+ "Passing `padding_mask` is deprecated. Use `attention_mask` instead."
889
+ )
890
  attention_mask = kwargs.pop("padding_mask")
891
 
892
+ output_attentions = False # flash attn 2 doesn't expose attention probs
893
 
894
  bsz, q_len, _ = hidden_states.shape
895
  if self.q_lora_rank is None:
 
900
  q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
901
 
902
  compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
903
+ compressed_kv, k_pe = torch.split(
904
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
905
+ )
906
  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
907
+ kv = (
908
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
909
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
910
+ .transpose(1, 2)
911
+ )
912
+ k_nope, value_states = torch.split(
913
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
914
+ )
915
  kv_seq_len = value_states.shape[-2]
916
  if past_key_value is not None:
917
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
918
+
919
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
920
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
921
 
922
  query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
923
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
924
  query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
925
 
926
  key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
927
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
928
  key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
929
 
930
  if self.q_head_dim != self.v_head_dim:
931
+ # Pad if needed
932
  value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
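+ # flash-attn expects Q, K and V to share the same head dim, so V is
+ # zero-padded up to q_head_dim here and the extra columns are sliced off
+ # after the kernel runs.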
933
 
934
  if past_key_value is not None:
935
+ cache_kwargs = {"sin": sin, "cos": cos}
936
+ key_states, value_states = past_key_value.update(
937
+ key_states, value_states, self.layer_idx, cache_kwargs
938
+ )
939
 
940
+ # Prepare for flash-attn which needs [bsz, seqlen, n_heads, head_dim]
941
  query_states = query_states.transpose(1, 2)
942
  key_states = key_states.transpose(1, 2)
943
  value_states = value_states.transpose(1, 2)
944
 
945
  dropout_rate = self.attention_dropout if self.training else 0.0
946
+
947
+ # flash-attn does not support float32; cast Q/K/V back to the parameter dtype if they were upcast
948
  input_dtype = query_states.dtype
949
  if input_dtype == torch.float32:
950
+ # Attempt to revert to original param dtype if different
951
+ target_dtype = (
952
+ self.q_proj.weight.dtype
953
+ if self.q_lora_rank is None
954
+ else self.q_a_proj.weight.dtype
955
+ )
956
  query_states = query_states.to(target_dtype)
957
  key_states = key_states.to(target_dtype)
958
  value_states = value_states.to(target_dtype)
959
 
960
+ # Flash attention pass
961
  attn_output = self._flash_attention_forward(
962
  query_states,
963
  key_states,
 
969
  )
970
 
971
  if self.q_head_dim != self.v_head_dim:
972
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
973
 
974
+ # [bsz, seqlen, n_heads, head_dim]
975
  attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
976
  attn_output = self.o_proj(attn_output)
977
+
978
  return attn_output, None, past_key_value
979
 
980
  def _flash_attention_forward(
 
987
  dropout: float = 0.0,
988
  softmax_scale: Optional[float] = None,
989
  ) -> torch.Tensor:
990
+ """
991
+ Wraps the flash-attn calls. If attention_mask has padding, we unpad first.
992
+ """
993
  if not self._flash_attn_uses_top_left_mask:
994
  causal = self.is_causal
995
  else:
996
+ # For flash_attn<2.1.0
997
  causal = self.is_causal and query_length != 1
998
 
999
  if attention_mask is not None:
 
1003
  value_states,
1004
  indices_q,
1005
  (cu_seqlens_q, cu_seqlens_k),
1006
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k)) = self._upad_input(
1007
+ query_states, key_states, value_states, attention_mask, query_length
1008
+ )
1009
  attn_output_unpad = flash_attn_varlen_func(
1010
  query_states,
1011
  key_states,
 
1018
  softmax_scale=softmax_scale,
1019
  causal=causal,
1020
  )
1021
+ attn_output = pad_input(
1022
+ attn_output_unpad, indices_q, batch_size, query_length
1023
+ )
1024
  else:
1025
  attn_output = flash_attn_func(
1026
  query_states,
 
1030
  softmax_scale=softmax_scale,
1031
  causal=causal,
1032
  )
1033
+
1034
  return attn_output
1035
 
1036
  def _upad_input(
 
1041
  attention_mask: torch.Tensor,
1042
  query_length: int,
1043
  ):
1044
+ """
1045
+ Unpads the Q, K, and V for FlashAttention in variable-length mode.
1046
+ """
1047
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1048
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1049
 
1050
+ key_layer = index_first_axis(
1051
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1052
+ indices_k,
1053
+ )
1054
+ value_layer = index_first_axis(
1055
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1056
+ indices_k,
1057
+ )
1058
  if query_length == kv_seq_len:
1059
+ query_layer = index_first_axis(
1060
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
1061
+ indices_k,
1062
+ )
1063
  cu_seqlens_q = cu_seqlens_k
1064
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
1065
  indices_q = indices_k
1066
  elif query_length == 1:
1067
  max_seqlen_in_batch_q = 1
1068
+ cu_seqlens_q = torch.arange(
1069
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
1070
+ )
1071
  indices_q = cu_seqlens_q[:-1]
1072
  query_layer = query_layer.squeeze(1)
1073
  else:
1074
+ # General case: re-unpad the query; the -query_length: slice assumes left padding
1075
  attention_mask = attention_mask[:, -query_length:]
1076
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
1077
+ query_layer, attention_mask
1078
+ )
1079
+
1080
+ return (
1081
+ query_layer,
1082
+ key_layer,
1083
+ value_layer,
1084
+ indices_q,
1085
+ (cu_seqlens_q, cu_seqlens_k),
1086
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1087
+ )
1088
 
1089
 
1090
+ # Attach the attention classes in a dictionary for easy selection
1091
  ATTENTION_CLASSES = {
1092
  "eager": DeepseekV3Attention,
1093
  "flash_attention_2": DeepseekV3FlashAttention2,
 
1105
  def __init__(self, config: DeepseekV3Config, layer_idx: int):
1106
  super().__init__()
1107
  self.hidden_size = config.hidden_size
1108
+
1109
+ self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
1110
+ config=config, layer_idx=layer_idx
1111
+ )
1112
+
1113
+ # Optionally use MoE
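+ # (e.g. with first_k_dense_replace=3 and moe_layer_freq=1, layers 0-2 keep a
+ # dense MLP and every layer from index 3 onward uses the MoE block)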
1114
+ if (
1115
+ config.n_routed_experts is not None
1116
+ and layer_idx >= config.first_k_dense_replace
1117
+ and layer_idx % config.moe_layer_freq == 0
1118
+ ):
1119
  self.mlp = DeepseekV3MoE(config)
1120
  else:
1121
  self.mlp = DeepseekV3MLP(config)
1122
+
1123
+ self.input_layernorm = DeepseekV3RMSNorm(
1124
+ config.hidden_size, eps=config.rms_norm_eps
1125
+ )
1126
+ self.post_attention_layernorm = DeepseekV3RMSNorm(
1127
+ config.hidden_size, eps=config.rms_norm_eps
1128
+ )
1129
 
1130
  def forward(
1131
  self,
 
1138
  **kwargs
1139
  ):
1140
  """
1141
+ Forward pass for one Deepseek decoder layer.
1142
  """
1143
  residual = hidden_states
1144
+
1145
+ # Pre-attention norm
1146
  hidden_states = self.input_layernorm(hidden_states)
1147
+
1148
+ # Self-attention
1149
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
1150
  hidden_states=hidden_states,
1151
  attention_mask=attention_mask,
 
1157
  )
1158
  hidden_states = residual + hidden_states
1159
 
1160
+ # Post-attention norm
1161
  residual = hidden_states
1162
  hidden_states = self.post_attention_layernorm(hidden_states)
 
 
 
 
1163
 
1164
+ # MLP or MoE
1165
  hidden_states = self.mlp(hidden_states)
1166
  hidden_states = residual + hidden_states
1167
 
1168
  outputs = (hidden_states,)
1169
  if output_attentions:
1170
  outputs += (self_attn_weights,)
1171
+
1172
  if use_cache:
1173
  outputs += (present_key_value,)
1174
+
1175
  return outputs
1176
 
1177
 
 
1181
 
1182
  DeepseekV3_START_DOCSTRING = r"""
1183
  This model inherits from `PreTrainedModel`. Check the superclass documentation
1184
+ for the generic methods the library implements for all its models (such as loading or saving, etc.).
1185
  """
1186
 
1187
  class DeepseekV3PreTrainedModel(PreTrainedModel):
 
1194
  _supports_cache_class = True
1195
 
1196
  def _init_weights(self, module):
1197
+ # IMPROVEMENT: Could add more robust initialization or variants (e.g., Xavier)
1198
  std = self.config.initializer_range
1199
  if isinstance(module, nn.Linear):
1200
  module.weight.data.normal_(mean=0.0, std=std)
 
1217
  input_ids (torch.LongTensor): shape `(batch_size, sequence_length)`
1218
  attention_mask (torch.Tensor): shape `(batch_size, sequence_length)` or `(batch_size, 1, seq_len, seq_len)`, optional.
1219
  position_ids (torch.LongTensor): shape `(batch_size, sequence_length)`, optional.
1220
+ past_key_values (Cache or tuple(tuple(torch.FloatTensor))), optional:
1221
+ Pre-computed hidden-states (key and values) that can be used to speed up sequential decoding.
1222
  inputs_embeds (torch.FloatTensor): shape `(batch_size, sequence_length, hidden_size)`, optional.
1223
+ use_cache (bool), optional
1224
+ output_attentions (bool), optional
1225
+ output_hidden_states (bool), optional
1226
+ return_dict (bool), optional
1227
  """
1228
 
1229
+ @add_start_docstrings(
1230
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
1231
+ DeepseekV3_START_DOCSTRING,
1232
+ )
1233
  class DeepseekV3Model(DeepseekV3PreTrainedModel):
1234
  """
1235
+ Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a `DeepseekV3DecoderLayer`.
1236
  """
1237
  def __init__(self, config: DeepseekV3Config):
1238
  super().__init__(config)
1239
  self.padding_idx = config.pad_token_id
1240
  self.vocab_size = config.vocab_size
1241
+
1242
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1243
+
1244
+ # Build decoder layers
1245
+ self.layers = nn.ModuleList([
1246
+ DeepseekV3DecoderLayer(config, layer_idx)
1247
+ for layer_idx in range(config.num_hidden_layers)
1248
+ ])
1249
+
1250
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1251
  self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1252
+
1253
  self.gradient_checkpointing = False
1254
  self.post_init()
 
 
1255
 
1256
  def get_input_embeddings(self) -> nn.Embedding:
1257
  return self.embed_tokens
 
1272
  output_hidden_states: Optional[bool] = None,
1273
  return_dict: Optional[bool] = None,
1274
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1275
+
1276
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1277
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None
1278
+ else self.config.output_hidden_states)
1279
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1280
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1281
 
 
1297
  past_key_values_length = past_key_values.get_usable_length(seq_length)
1298
 
1299
  if position_ids is None:
1300
+ device = (input_ids.device if input_ids is not None else inputs_embeds.device)
1301
+ position_ids = torch.arange(
1302
+ past_key_values_length,
1303
+ seq_length + past_key_values_length,
1304
+ dtype=torch.long,
1305
+ device=device
1306
+ )
1307
  position_ids = position_ids.unsqueeze(0)
1308
 
1309
  if inputs_embeds is None:
1310
  inputs_embeds = self.embed_tokens(input_ids)
1311
 
1312
+ # If flash attention is used, we pass 2D mask to the layers
1313
  if self._use_flash_attention_2:
1314
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
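+ # A mask with no zeros means nothing is padded, so None is passed through and
+ # flash-attn can take the faster dense (non-varlen) path.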
1315
  else:
1316
+ # standard 4D mask
1317
+ attention_mask = _prepare_4d_causal_attention_mask(
1318
+ attention_mask,
1319
+ (batch_size, seq_length),
1320
+ inputs_embeds,
1321
+ past_key_values_length,
1322
+ )
1323
 
1324
  hidden_states = inputs_embeds
1325
 
 
1330
  for idx, decoder_layer in enumerate(self.layers):
1331
  if output_hidden_states:
1332
  all_hidden_states += (hidden_states,)
1333
+
1334
+ # Potential gradient checkpointing
1335
  if self.gradient_checkpointing and self.training:
1336
  def create_custom_forward(module):
1337
  def custom_forward(*inputs):
1338
  return module(*inputs, output_attentions=output_attentions, use_cache=use_cache)
1339
  return custom_forward
1340
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1341
+ create_custom_forward(decoder_layer),
1342
+ hidden_states,
1343
+ attention_mask,
1344
+ position_ids,
1345
+ past_key_values
1346
+ )
1347
  else:
1348
+ layer_outputs = decoder_layer(
1349
+ hidden_states,
1350
+ attention_mask=attention_mask,
1351
+ position_ids=position_ids,
1352
+ past_key_value=past_key_values,
1353
+ output_attentions=output_attentions,
1354
+ use_cache=use_cache,
1355
+ )
1356
+
1357
  hidden_states = layer_outputs[0]
1358
  if use_cache:
1359
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1360
+
1361
  if output_attentions:
1362
  all_self_attns += (layer_outputs[1],)
1363
+
1364
  hidden_states = self.norm(hidden_states)
1365
  if output_hidden_states:
1366
  all_hidden_states += (hidden_states,)
1367
 
1368
+ # Prepare next_cache
1369
+ next_cache = None
1370
+ if use_cache:
1371
+ next_cache = (
1372
+ next_decoder_cache.to_legacy_cache()
1373
+ if use_legacy_cache
1374
+ else next_decoder_cache
1375
+ )
1376
 
1377
  if not return_dict:
1378
+ return tuple(
1379
+ v
1380
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1381
+ if v is not None
1382
+ )
1383
+
1384
  return BaseModelOutputWithPast(
1385
  last_hidden_state=hidden_states,
1386
  past_key_values=next_cache,
 
1401
  self.model = DeepseekV3Model(config)
1402
  self.vocab_size = config.vocab_size
1403
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1404
+
1405
  self.post_init()
1406
 
1407
  def get_input_embeddings(self) -> nn.Embedding:
 
1423
  return self.model
1424
 
1425
  @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1426
+ @replace_return_docstrings(
1427
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1428
+ )
1429
  def forward(
1430
  self,
1431
  input_ids: Optional[torch.LongTensor] = None,
 
1439
  output_hidden_states: Optional[bool] = None,
1440
  return_dict: Optional[bool] = None,
1441
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1442
+ """
1443
+ Args:
1444
+ labels (torch.LongTensor of shape (batch_size, sequence_length), optional):
1445
+ Labels for computing the language modeling loss. Indices should be in [0, ..., config.vocab_size - 1]; positions set to -100 are ignored.
1446
+ """
1447
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1448
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None
1449
+ else self.config.output_hidden_states)
1450
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1451
 
1452
+ # Decoder forward
1453
+ outputs = self.model(
1454
+ input_ids=input_ids,
1455
+ attention_mask=attention_mask,
1456
+ position_ids=position_ids,
1457
+ past_key_values=past_key_values,
1458
+ inputs_embeds=inputs_embeds,
1459
+ use_cache=use_cache,
1460
+ output_attentions=output_attentions,
1461
+ output_hidden_states=output_hidden_states,
1462
+ return_dict=return_dict,
1463
+ )
1464
+
 
 
 
 
 
1465
  hidden_states = outputs[0]
1466
+ logits = self.lm_head(hidden_states)
1467
+ logits = logits.float() # IMPROVEMENT: Could keep FP16 if stable
1468
+
1469
  loss = None
1470
  if labels is not None:
1471
+ # SHIFT
1472
  shift_logits = logits[..., :-1, :].contiguous()
1473
  shift_labels = labels[..., 1:].contiguous()
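+ # Logits at position t predict the token at position t+1, so logits and
+ # labels are shifted by one relative to each other before the loss.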
1474
+
1475
  loss_fct = CrossEntropyLoss()
1476
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1477
+ shift_labels = shift_labels.view(-1)
1478
+ shift_labels = shift_labels.to(shift_logits.device)
1479
  loss = loss_fct(shift_logits, shift_labels)
1480
 
1481
  if not return_dict:
 
1498
  inputs_embeds: Optional[torch.FloatTensor] = None,
1499
  **kwargs
1500
  ):
1501
+ """
1502
+ Prepare model inputs for each decoding step of the generation loop.
1503
+ """
1504
  if past_key_values is not None:
1505
  if isinstance(past_key_values, Cache):
1506
  cache_length = past_key_values.get_seq_length()
 
1509
  else:
1510
  cache_length = past_length = past_key_values[0][0].shape[2]
1511
  max_cache_length = None
1512
+
1513
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1514
+ # keep only the tokens that have not been processed yet
1515
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1516
  elif past_length < input_ids.shape[1]:
1517
  input_ids = input_ids[:, past_length:]
1518
+
1519
  if max_cache_length is not None and attention_mask is not None:
1520
  if cache_length + input_ids.shape[1] > max_cache_length:
1521
  attention_mask = attention_mask[:, -max_cache_length:]
1522
+
1523
  position_ids = kwargs.get("position_ids", None)
1524
  if attention_mask is not None and position_ids is None:
1525
  position_ids = attention_mask.long().cumsum(-1) - 1
1526
  position_ids.masked_fill_(attention_mask == 0, 1)
1527
  if past_key_values:
1528
+ position_ids = position_ids[:, -input_ids.shape[1] :]
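+ # The cumsum over the attention mask yields 0-based positions for real tokens
+ # (padding gets a dummy value of 1), and only the positions for the tokens
+ # being fed in this step are kept.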
1529
+
1530
+ # inputs_embeds are only used for the first generation step (no cache yet)
1531
  if inputs_embeds is not None and past_key_values is None:
1532
  model_inputs = {"inputs_embeds": inputs_embeds}
1533
  else:
1534
  model_inputs = {"input_ids": input_ids}
1535
+
1536
+ model_inputs.update(
1537
+ {
1538
+ "position_ids": position_ids,
1539
+ "past_key_values": past_key_values,
1540
+ "use_cache": kwargs.get("use_cache"),
1541
+ "attention_mask": attention_mask,
1542
+ }
1543
+ )
1544
  return model_inputs
1545
 
1546
  @staticmethod
1547
  def _reorder_cache(past_key_values: Tuple, beam_idx: torch.Tensor) -> Tuple:
1548
  reordered_past = ()
1549
  for layer_past in past_key_values:
1550
+ reordered_past += (
1551
+ tuple(
1552
+ past_state.index_select(0, beam_idx.to(past_state.device))
1553
+ for past_state in layer_past
1554
+ ),
1555
+ )
1556
  return reordered_past
1557
 
1558
 
 
1573
  self.num_labels = config.num_labels
1574
  self.model = DeepseekV3Model(config)
1575
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1576
+
1577
  self.post_init()
1578
 
1579
  def get_input_embeddings(self) -> nn.Embedding:
 
1596
  output_hidden_states: Optional[bool] = None,
1597
  return_dict: Optional[bool] = None,
1598
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1599
+
1600
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1601
  transformer_outputs = self.model(
1602
  input_ids,
 
1617
  else:
1618
  batch_size = inputs_embeds.shape[0]
1619
 
1620
+ # If no pad_token_id, assume last token for each sample
1621
  if self.config.pad_token_id is None and batch_size != 1:
1622
+ raise ValueError(
1623
+ "Cannot handle batch sizes > 1 if no pad token is defined."
1624
+ )
1625
  if self.config.pad_token_id is None:
1626
  sequence_lengths = -1
1627
  else:
1628
  if input_ids is not None:
1629
+ sequence_lengths = (
1630
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1631
+ ).to(logits.device)
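+ # argmax over (input_ids == pad_token_id) finds the first pad position; minus
+ # one gives the index of the last real token (assumes right padding; with no
+ # pad present, -1 also selects the last position).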
1632
  else:
1633
  sequence_lengths = -1
1634
 
1635
+ pooled_logits = logits[
1636
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1637
+ ]
1638
 
1639
  loss = None
1640
  if labels is not None:
 
1646
  self.config.problem_type = "single_label_classification"
1647
  else:
1648
  self.config.problem_type = "multi_label_classification"
1649
+
1650
  if self.config.problem_type == "regression":
1651
  loss_fct = MSELoss()
1652
+ if self.num_labels == 1:
1653
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1654
+ else:
1655
+ loss = loss_fct(pooled_logits, labels)
1656
  elif self.config.problem_type == "single_label_classification":
1657
  loss_fct = CrossEntropyLoss()
1658
+ loss = loss_fct(
1659
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1660
+ )
1661
  elif self.config.problem_type == "multi_label_classification":
1662
  loss_fct = BCEWithLogitsLoss()
1663
  loss = loss_fct(pooled_logits, labels)