wuhp committed on
Commit 16ee998 · verified · 1 Parent(s): d4a6983

Create modeling_deepseek.py

Files changed (1)
  1. myr1/modeling_deepseek.py +1675 -0
myr1/modeling_deepseek.py ADDED
@@ -0,0 +1,1675 @@
1
+ """
2
+ modeling_deepseek.py
3
+
4
+ An improved version of the DeepSeekV3 model code with added docstrings, in-line commentary,
5
+ some mild refactoring, and suggestions for potential future enhancements. This version is
6
+ intended for demonstration and testing. Actual performance gains may vary based on your
7
+ environment and training data.
8
+ """
9
+
10
+ import math
11
+ import warnings
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint
17
+ from torch import nn
18
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
19
+
20
+ from transformers.activations import ACT2FN
21
+ from transformers.cache_utils import Cache, DynamicCache
22
+ from transformers.modeling_attn_mask_utils import (
23
+ AttentionMaskConverter,
24
+ _prepare_4d_attention_mask,
25
+ _prepare_4d_causal_attention_mask,
26
+ )
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPast,
29
+ CausalLMOutputWithPast,
30
+ SequenceClassifierOutputWithPast,
31
+ )
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.pytorch_utils import (
34
+ ALL_LAYERNORM_LAYERS,
35
+ is_torch_greater_or_equal_than_1_13,
36
+ )
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from transformers.utils.import_utils import is_torch_fx_available
46
+
47
+ # Import the model configuration that ships alongside this file
48
+ from .configuration_deepseek import DeepseekV3Config
49
+
50
+ import torch.distributed as dist
51
+ import numpy as np
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ # If flash-attn is available
56
+ if is_flash_attn_2_available():
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
59
+
60
+ # This helps make `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
61
+ if is_torch_fx_available():
62
+ if not is_torch_greater_or_equal_than_1_13:
63
+ import torch.fx
64
+
65
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
66
+
67
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
68
+
69
+
70
+ # ==============================================================================
71
+ # Rotary Embedding Helpers
72
+ # ==============================================================================
73
+
74
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
75
+ """
76
+ Prepares unpadded data indices for varlen attention.
77
+ Returns:
78
+ indices: Flattened indices where mask=1.
79
+ cu_seqlens: prefix-summed lengths for each sequence.
80
+ max_seqlen_in_batch: maximum sequence length in the batch.
81
+ """
82
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
83
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
84
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
85
+ # Build prefix sums
86
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
87
+ return indices, cu_seqlens, max_seqlen_in_batch
88
+
89
+
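+ # Illustrative sketch: what `_get_unpad_data` produces for a toy padded batch. With two
+ # sequences of lengths 2 and 3, `cu_seqlens` is the cumulative-length vector [0, 2, 5] that
+ # the flash-attn varlen kernels consume.
+ def _example_get_unpad_data():
+     mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
+     indices, cu_seqlens, max_len = _get_unpad_data(mask)
+     # indices -> tensor([0, 1, 3, 4, 5]); cu_seqlens -> tensor([0, 2, 5]); max_len -> 3
+     return indices, cu_seqlens, max_len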
90
+ # ==============================================================================
91
+ # Normalization Layers
92
+ # ==============================================================================
93
+
94
+ class DeepseekV3RMSNorm(nn.Module):
95
+ """
96
+ DeepseekV3RMSNorm is essentially a Root Mean Square Layer Normalization.
97
+ This can be more stable than standard LayerNorm in some scenarios.
98
+ """
99
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
100
+ super().__init__()
101
+ self.weight = nn.Parameter(torch.ones(hidden_size))
102
+ self.variance_epsilon = eps
103
+
104
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
105
+ # Compute the variance in float32 for numerical stability, then cast back to the input dtype
106
+ input_dtype = hidden_states.dtype
107
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
108
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
109
+ return (self.weight * hidden_states).to(input_dtype)
110
+
111
+
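+ # Illustrative sketch: RMSNorm performs no mean subtraction and has no bias; it only rescales
+ # each token by the root-mean-square of its features and a learned per-channel weight.
+ def _example_rmsnorm_equivalence(hidden_size: int = 8) -> bool:
+     norm = DeepseekV3RMSNorm(hidden_size)
+     x = torch.randn(2, 4, hidden_size)
+     manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
+     return torch.allclose(norm(x), norm.weight * manual, atol=1e-6)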
112
+ ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
113
+
114
+
115
+ # ==============================================================================
116
+ # Rotary Embeddings
117
+ # ==============================================================================
118
+
119
+ class DeepseekV3RotaryEmbedding(nn.Module):
120
+ """
121
+ Base Rotary Position Embedding for the Deepseek architecture.
122
+ """
123
+ def __init__(
124
+ self,
125
+ dim: int,
126
+ max_position_embeddings: int = 2048,
127
+ base: int = 10000,
128
+ device: Optional[torch.device] = None
129
+ ):
130
+ super().__init__()
131
+ self.dim = dim
132
+ self.max_position_embeddings = max_position_embeddings
133
+ self.base = base
134
+
135
+ inv_freq = 1.0 / (
136
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
137
+ )
138
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
139
+
140
+ # Build here to make `torch.jit.trace` work.
141
+ self._set_cos_sin_cache(
142
+ seq_len=max_position_embeddings,
143
+ device=self.inv_freq.device,
144
+ dtype=torch.get_default_dtype(),
145
+ )
146
+ self.max_seq_len_cached = None
147
+
148
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
149
+ self.max_seq_len_cached = seq_len
150
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
151
+
152
+ freqs = torch.outer(t, self.inv_freq.to(t.device))
153
+ # Differs from the paper's formulation, but an equivalent permutation gives the same result
154
+ emb = torch.cat((freqs, freqs), dim=-1)
155
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
156
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
157
+
158
+ def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
159
+ """
160
+ x: [batch_size, num_heads, seq_len, head_size]
161
+ """
162
+ if (self.max_seq_len_cached is None) or (seq_len and seq_len > self.max_seq_len_cached):
163
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
164
+
165
+ return (self.cos_cached[:seq_len].to(dtype=x.dtype),
166
+ self.sin_cached[:seq_len].to(dtype=x.dtype))
167
+
168
+
169
+ class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
170
+ """
171
+ RoPE extended with linear scaling. Credits to the Reddit user /u/kaiokendev
172
+ """
173
+ def __init__(
174
+ self,
175
+ dim: int,
176
+ max_position_embeddings: int = 2048,
177
+ base: int = 10000,
178
+ device: Optional[torch.device] = None,
179
+ scaling_factor: float = 1.0
180
+ ):
181
+ self.scaling_factor = scaling_factor
182
+ super().__init__(dim, max_position_embeddings, base, device)
183
+
184
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
185
+ self.max_seq_len_cached = seq_len
186
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
187
+ t = t / self.scaling_factor
188
+ freqs = torch.outer(t, self.inv_freq)
189
+ emb = torch.cat((freqs, freqs), dim=-1)
190
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
191
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
192
+
193
+
194
+ class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
195
+ """
196
+ RoPE extended with Dynamic NTK scaling.
197
+ Credits to the Reddit users /u/bloc97 and /u/emozilla
198
+ """
199
+ def __init__(
200
+ self,
201
+ dim: int,
202
+ max_position_embeddings: int = 2048,
203
+ base: int = 10000,
204
+ device: Optional[torch.device] = None,
205
+ scaling_factor: float = 1.0
206
+ ):
207
+ self.scaling_factor = scaling_factor
208
+ super().__init__(dim, max_position_embeddings, base, device)
209
+
210
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
211
+ self.max_seq_len_cached = seq_len
212
+
213
+ if seq_len > self.max_position_embeddings:
214
+ base = self.base * (
215
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
216
+ - (self.scaling_factor - 1)
217
+ ) ** (self.dim / (self.dim - 2))
218
+ inv_freq = 1.0 / (
219
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
220
+ )
221
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
222
+
223
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
224
+ freqs = torch.outer(t, self.inv_freq)
225
+ emb = torch.cat((freqs, freqs), dim=-1)
226
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
227
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
228
+
229
+
230
+ # Additional YaRN helper formulas, following the original DeepSeek-V3 implementation
231
+ def yarn_find_correction_dim(
232
+ num_rotations: float,
233
+ dim: int,
234
+ base: int = 10000,
235
+ max_position_embeddings: int = 2048
236
+ ):
237
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
238
+ 2 * math.log(base)
239
+ )
240
+
241
+
242
+ def yarn_find_correction_range(
243
+ low_rot: float,
244
+ high_rot: float,
245
+ dim: int,
246
+ base: int = 10000,
247
+ max_position_embeddings: int = 2048
248
+ ):
249
+ low = math.floor(
250
+ yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
251
+ )
252
+ high = math.ceil(
253
+ yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
254
+ )
255
+ # Clamped range
256
+ return max(low, 0), min(high, dim - 1)
257
+
258
+
259
+ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
260
+ if scale <= 1:
261
+ return 1.0
262
+ return 0.1 * mscale * math.log(scale) + 1.0
263
+
264
+
265
+ def yarn_linear_ramp_mask(min_i: int, max_i: int, dim: int) -> torch.Tensor:
266
+ if min_i == max_i:
267
+ max_i += 0.001
268
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min_i) / (max_i - min_i)
269
+ ramp_func = torch.clamp(linear_func, 0, 1)
270
+ return ramp_func
271
+
272
+
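+ # Illustrative sketch: the YaRN helpers in action. For a 4x context extension,
+ # `yarn_get_mscale(4.0)` yields the attention-temperature multiplier 0.1 * ln(4) + 1 ≈ 1.139,
+ # and the ramp mask blends extrapolated and interpolated frequencies across rotary dimensions.
+ def _example_yarn_helpers(dim: int = 64):
+     low, high = yarn_find_correction_range(32, 1, dim, base=10000, max_position_embeddings=4096)
+     ramp = yarn_linear_ramp_mask(low, high, dim // 2)  # values in [0, 1]
+     return yarn_get_mscale(4.0), low, high, ramp.shape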
273
+ class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
274
+ """
275
+ YaRN-extended Rotary Embedding, following the original DeepSeek-V3 implementation.
276
+ """
277
+ def __init__(
278
+ self,
279
+ dim: int,
280
+ max_position_embeddings: int = 2048,
281
+ base: int = 10000,
282
+ device: Optional[torch.device] = None,
283
+ scaling_factor: float = 1.0,
284
+ original_max_position_embeddings: int = 4096,
285
+ beta_fast: float = 32,
286
+ beta_slow: float = 1,
287
+ mscale: float = 1,
288
+ mscale_all_dim: float = 0,
289
+ ):
290
+ self.scaling_factor = scaling_factor
291
+ self.original_max_position_embeddings = original_max_position_embeddings
292
+ self.beta_fast = beta_fast
293
+ self.beta_slow = beta_slow
294
+ self.mscale = mscale
295
+ self.mscale_all_dim = mscale_all_dim
296
+ super().__init__(dim, max_position_embeddings, base, device)
297
+
298
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
299
+ self.max_seq_len_cached = seq_len
300
+ dim = self.dim
301
+
302
+ freq_extra = 1.0 / (
303
+ self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
304
+ )
305
+ freq_inter = 1.0 / (
306
+ self.scaling_factor
307
+ * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
308
+ )
309
+
310
+ low, high = yarn_find_correction_range(
311
+ self.beta_fast,
312
+ self.beta_slow,
313
+ dim,
314
+ self.base,
315
+ self.original_max_position_embeddings,
316
+ )
317
+ inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
318
+ inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
319
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
320
+
321
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
322
+ freqs = torch.outer(t, inv_freq)
323
+ _mscale = float(
324
+ yarn_get_mscale(self.scaling_factor, self.mscale)
325
+ / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
326
+ )
327
+
328
+ emb = torch.cat((freqs, freqs), dim=-1)
329
+ self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
330
+ self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
331
+
332
+
333
+ # ==============================================================================
334
+ # General Rotary helper functions
335
+ # ==============================================================================
336
+
337
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
338
+ """Rotates half the hidden dims of the input."""
339
+ x1 = x[..., : x.shape[-1] // 2]
340
+ x2 = x[..., x.shape[-1] // 2 :]
341
+ return torch.cat((-x2, x1), dim=-1)
342
+
343
+
344
+ def apply_rotary_pos_emb(
345
+ q: torch.Tensor,
346
+ k: torch.Tensor,
347
+ cos: torch.Tensor,
348
+ sin: torch.Tensor,
349
+ position_ids: torch.Tensor,
350
+ unsqueeze_dim: int = 1
351
+ ):
352
+ """
353
+ Applies Rotary Position Embedding to the query and key tensors.
354
+ """
355
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
356
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
357
+
358
+ b, h, s, d = q.shape
359
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
360
+
361
+ b, h, s, d = k.shape
362
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
363
+
364
+ q_embed = (q * cos) + (rotate_half(q) * sin)
365
+ k_embed = (k * cos) + (rotate_half(k) * sin)
366
+ return q_embed, k_embed
367
+
368
+
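+ # Illustrative sketch: applying RoPE to toy query/key tensors. Shapes follow the convention
+ # used above: q and k are [batch, heads, seq, head_dim], and cos/sin come from one of the
+ # rotary modules defined earlier.
+ def _example_apply_rope(head_dim: int = 16, seq_len: int = 4):
+     rope = DeepseekV3RotaryEmbedding(head_dim, max_position_embeddings=seq_len)
+     q = torch.randn(1, 2, seq_len, head_dim)
+     k = torch.randn(1, 2, seq_len, head_dim)
+     cos, sin = rope(q, seq_len=seq_len)
+     position_ids = torch.arange(seq_len).unsqueeze(0)
+     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
+     return q_rot.shape, k_rot.shape  # both stay [1, 2, seq_len, head_dim]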
369
+ # ==============================================================================
370
+ # MLP and MoE Modules
371
+ # ==============================================================================
372
+
373
+ class DeepseekV3MLP(nn.Module):
374
+ """
375
+ Simple MLP block with gating (SwiGLU style).
376
+ """
377
+ def __init__(self, config: DeepseekV3Config,
378
+ hidden_size: Optional[int] = None,
379
+ intermediate_size: Optional[int] = None):
380
+ super().__init__()
381
+ self.config = config
382
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
383
+ self.intermediate_size = (
384
+ config.intermediate_size if intermediate_size is None else intermediate_size
385
+ )
386
+
387
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
388
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
389
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
390
+ self.act_fn = ACT2FN[config.hidden_act]
391
+
392
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
393
+ gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
394
+ return self.down_proj(gated)
395
+
396
+
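+ # Illustrative sketch: the MLP computes a SwiGLU-style transform,
+ # down_proj(act(gate_proj(x)) * up_proj(x)). A SimpleNamespace stands in for a full
+ # DeepseekV3Config here; the attribute values are illustrative only.
+ def _example_mlp_forward():
+     from types import SimpleNamespace
+     cfg = SimpleNamespace(hidden_size=16, intermediate_size=32, hidden_act="silu")
+     mlp = DeepseekV3MLP(cfg)
+     x = torch.randn(2, 5, 16)
+     return mlp(x).shape  # torch.Size([2, 5, 16])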
397
+ class MoEGate(nn.Module):
398
+ """
399
+ Expert gating mechanism for MoE. This could be enhanced with other gating strategies.
400
+ """
401
+ def __init__(self, config: DeepseekV3Config):
402
+ super().__init__()
403
+ self.config = config
404
+ self.top_k = config.num_experts_per_tok
405
+ self.n_routed_experts = config.n_routed_experts
406
+ self.routed_scaling_factor = config.routed_scaling_factor
407
+ self.scoring_func = config.scoring_func
408
+ self.seq_aux = config.seq_aux
409
+ self.topk_method = config.topk_method
410
+ self.n_group = config.n_group
411
+ self.topk_group = config.topk_group
412
+
413
+ self.norm_topk_prob = config.norm_topk_prob
414
+ self.gating_dim = config.hidden_size
415
+
416
+ # Gating weight
417
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
418
+
419
+ if self.topk_method == "noaux_tc":
420
+ self.e_score_correction_bias = nn.Parameter(
421
+ torch.empty((self.n_routed_experts))
422
+ )
423
+
424
+ self.reset_parameters()
425
+
426
+ def reset_parameters(self):
427
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
428
+ if self.topk_method == "noaux_tc":
429
+ nn.init.constant_(self.e_score_correction_bias, 0.0)
430
+
431
+ def forward(self, hidden_states: torch.Tensor):
432
+ """
433
+ Compute gating scores and select top-k experts.
434
+ """
435
+ bsz, seq_len, h = hidden_states.shape
436
+
437
+ # 1) Compute gating scores
438
+ logits = F.linear(hidden_states.float(), self.weight.float(), None)
439
+ if self.scoring_func == "sigmoid":
440
+ scores = logits.sigmoid()
441
+ else:
442
+ raise NotImplementedError(
443
+ f"Unsupported gating scoring function: {self.scoring_func}"
444
+ )
445
+
446
+ # 2) TopK selection
447
+ if self.topk_method == "noaux_tc":
448
+ # Group-limited ("noaux_tc") routing from the original DeepSeek-V3 implementation
449
+ # IMPROVEMENT: Could consider generalizing to top2 gating or other advanced techniques
450
+ scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
451
+ group_scores = (
452
+ scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
453
+ )
454
+ group_idx = torch.topk(
455
+ group_scores, k=self.topk_group, dim=-1, sorted=False
456
+ )[1] # [n, top_k_group]
457
+ group_mask = torch.zeros_like(group_scores)
458
+ group_mask.scatter_(1, group_idx, 1)
459
+ score_mask = group_mask.unsqueeze(-1).expand(
460
+ bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
461
+ ).reshape(bsz * seq_len, -1)
462
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
463
+ _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
464
+ topk_weight = scores_for_choice.gather(1, topk_idx)
465
+ else:
466
+ raise NotImplementedError(
467
+ f"Unsupported topk_method: {self.topk_method}"
468
+ )
469
+
470
+ # 3) Norm gate to sum to 1 if top_k > 1
471
+ if self.top_k > 1 and self.norm_topk_prob:
472
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
473
+ topk_weight = topk_weight / denominator
474
+
475
+ # 4) Multiply scaling factor
476
+ topk_weight = topk_weight * self.routed_scaling_factor
477
+
478
+ return topk_idx, topk_weight
479
+
480
+
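+ # Illustrative sketch: routing a toy batch through the gate. A SimpleNamespace stands in for
+ # DeepseekV3Config; the values are chosen only so that the "noaux_tc" group-limited routing
+ # path above is exercised.
+ def _example_moe_gate_routing():
+     from types import SimpleNamespace
+     cfg = SimpleNamespace(
+         num_experts_per_tok=2, n_routed_experts=8, routed_scaling_factor=1.0,
+         scoring_func="sigmoid", seq_aux=False, topk_method="noaux_tc",
+         n_group=4, topk_group=2, norm_topk_prob=True, hidden_size=16,
+     )
+     gate = MoEGate(cfg)
+     topk_idx, topk_weight = gate(torch.randn(2, 3, 16))
+     return topk_idx.shape, topk_weight.shape  # both [2 * 3, num_experts_per_tok]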
481
+ class DeepseekV3MoE(nn.Module):
482
+ """
483
+ A mixture-of-experts module. Uses gating to route tokens to certain experts.
484
+ """
485
+ def __init__(self, config: DeepseekV3Config):
486
+ super().__init__()
487
+ self.config = config
488
+ self.num_experts_per_tok = config.num_experts_per_tok
489
+
490
+ self.ep_size = getattr(config, "ep_size", 1)
491
+ self.experts_per_rank = config.n_routed_experts
492
+ self.ep_rank = 0
493
+ if self.ep_size > 1:
494
+ assert self.ep_size == dist.get_world_size()
495
+ self.experts_per_rank = config.n_routed_experts // config.ep_size
496
+ self.ep_rank = dist.get_rank()
497
+
498
+ # Build experts
499
+ experts_list = []
500
+ for i in range(config.n_routed_experts):
501
+ # only build if belongs to current rank
502
+ if self.ep_size > 1:
503
+ if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank:
504
+ experts_list.append(
505
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
506
+ )
507
+ else:
508
+ experts_list.append(None)
509
+ else:
510
+ experts_list.append(
511
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
512
+ )
513
+ self.experts = nn.ModuleList(experts_list)
514
+
515
+ # Gate
516
+ self.gate = MoEGate(config)
517
+
518
+ # Optionally shared experts
519
+ if config.n_shared_experts is not None:
520
+ intermediate_size = config.moe_intermediate_size * config.n_shared_experts
521
+ self.shared_experts = DeepseekV3MLP(
522
+ config=config, intermediate_size=intermediate_size
523
+ )
524
+ else:
525
+ self.shared_experts = None
526
+
527
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
528
+ identity = hidden_states
529
+ orig_shape = hidden_states.shape
530
+
531
+ topk_idx, topk_weight = self.gate(hidden_states)
532
+ # Flatten
533
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
534
+
535
+ # Inference
536
+ if not self.training:
537
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
538
+ else:
539
+ # For training, a real MoE implementation would use a differentiable dispatch
540
+ # (and typically an auxiliary load-balancing loss). Note that `moe_infer` runs under
541
+ # `torch.no_grad()`, so this placeholder will not propagate gradients through the experts.
542
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
543
+
544
+ # Add shared experts if present
545
+ if self.shared_experts is not None:
546
+ y = y + self.shared_experts(identity)
547
+
548
+ return y
549
+
550
+ @torch.no_grad()
551
+ def moe_infer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
552
+ """
553
+ MoE inference path for each token. This code can be parallelized or distributed for better performance.
554
+ """
555
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
556
+ cnts.scatter_(1, topk_ids, 1)
557
+ tokens_per_expert = cnts.sum(dim=0)
558
+ idxs = topk_ids.view(-1).argsort()
559
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
560
+ sorted_tokens_shape = sorted_tokens.shape
561
+
562
+ # Handle distribution if ep_size>1
563
+ if self.ep_size > 1:
564
+ tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
565
+ tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
566
+ dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
567
+ output_splits = (
568
+ tokens_per_expert_group.view(self.ep_size, self.experts_per_rank)
569
+ .sum(1)
570
+ .cpu()
571
+ .numpy()
572
+ .tolist()
573
+ )
574
+ gathered_tokens = sorted_tokens.new_empty(
575
+ tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
576
+ )
577
+ input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
578
+ dist.all_to_all(
579
+ list(gathered_tokens.split(input_split_sizes)),
580
+ list(sorted_tokens.split(input_split_sizes)),
581
+ )
582
+ tokens_per_expert_post_gather = tokens_per_expert_group.view(
583
+ self.ep_size, self.experts_per_rank
584
+ ).sum(dim=0)
585
+ gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
586
+ s = 0
587
+ for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
588
+ gatherd_idxs[s : s + k] = i % self.experts_per_rank
589
+ s += k
590
+ gatherd_idxs = gatherd_idxs.argsort()
591
+ sorted_tokens = gathered_tokens[gatherd_idxs]
592
+ tokens_per_expert = tokens_per_expert_post_gather
593
+
594
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
595
+
596
+ outputs = []
597
+ start_idx = 0
598
+ # Forward pass for each expert’s assigned tokens
599
+ for i, num_tokens in enumerate(tokens_per_expert):
600
+ end_idx = start_idx + num_tokens
601
+ if num_tokens == 0:
602
+ continue
603
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
604
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
605
+ expert_out = expert(tokens_for_this_expert) if expert else tokens_for_this_expert
606
+ outputs.append(expert_out)
607
+ start_idx = end_idx
608
+
609
+ outs = (
610
+ torch.cat(outputs, dim=0)
611
+ if len(outputs)
612
+ else sorted_tokens.new_empty(0)
613
+ )
614
+
615
+ if self.ep_size > 1:
616
+ new_x = torch.empty_like(outs)
617
+ new_x[gatherd_idxs] = outs
618
+ gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
619
+ dist.all_to_all(
620
+ list(gathered_tokens.split(input_split_sizes)),
621
+ list(new_x.split(output_splits)),
622
+ )
623
+ outs = gathered_tokens
624
+
625
+ new_x = torch.empty_like(outs)
626
+ new_x[idxs] = outs
627
+ final_out = (
628
+ new_x.view(*topk_ids.shape, -1)
629
+ .type(topk_weight.dtype)
630
+ .mul_(topk_weight.unsqueeze(dim=-1))
631
+ .sum(dim=1)
632
+ .type(new_x.dtype)
633
+ )
634
+ return final_out
635
+
636
+
637
+ # Utility to repeat KV states if needed
638
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
639
+ """
640
+ Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep).
641
+ For dimension usage: (batch, num_heads, seqlen, head_dim).
642
+ """
643
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
644
+ if n_rep == 1:
645
+ return hidden_states
646
+ hidden_states = hidden_states[:, :, None, :, :].expand(
647
+ batch, num_key_value_heads, n_rep, slen, head_dim
648
+ )
649
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
650
+
651
+
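+ # Illustrative sketch: `repeat_kv` matches `torch.repeat_interleave` along the head dimension.
+ def _example_repeat_kv_equivalence() -> bool:
+     kv = torch.randn(1, 2, 5, 4)  # [batch, kv_heads, seq, head_dim]
+     return torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))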
652
+ # ==============================================================================
653
+ # Attention Modules
654
+ # ==============================================================================
655
+
656
+ class DeepseekV3Attention(nn.Module):
657
+ """
658
+ Standard multi-headed attention for Deepseek.
659
+ """
660
+ def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
661
+ super().__init__()
662
+ self.config = config
663
+ self.layer_idx = layer_idx
664
+
665
+ self.attention_dropout = config.attention_dropout
666
+ self.hidden_size = config.hidden_size
667
+ self.num_heads = config.num_attention_heads
668
+
669
+ self.max_position_embeddings = config.max_position_embeddings
670
+ self.rope_theta = config.rope_theta
671
+ self.q_lora_rank = config.q_lora_rank
672
+ self.qk_rope_head_dim = config.qk_rope_head_dim
673
+ self.kv_lora_rank = config.kv_lora_rank
674
+ self.v_head_dim = config.v_head_dim
675
+ self.qk_nope_head_dim = config.qk_nope_head_dim
676
+ self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
677
+
678
+ self.is_causal = True
679
+
680
+ # Q-proj
681
+ if self.q_lora_rank is None:
682
+ self.q_proj = nn.Linear(
683
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False
684
+ )
685
+ else:
686
+ self.q_a_proj = nn.Linear(
687
+ self.hidden_size, config.q_lora_rank, bias=config.attention_bias
688
+ )
689
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
690
+ self.q_b_proj = nn.Linear(
691
+ config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
692
+ )
693
+
694
+ # K,V-proj (MQA style)
695
+ self.kv_a_proj_with_mqa = nn.Linear(
696
+ self.hidden_size,
697
+ config.kv_lora_rank + config.qk_rope_head_dim,
698
+ bias=config.attention_bias,
699
+ )
700
+ self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
701
+ self.kv_b_proj = nn.Linear(
702
+ config.kv_lora_rank,
703
+ self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
704
+ bias=False,
705
+ )
706
+
707
+ # Out proj
708
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias)
709
+
710
+ # Build the rotary embedding
711
+ self._init_rope()
712
+
713
+ # Softmax scale; adjusted below when YaRN-style rope_scaling is configured
714
+ self.softmax_scale = self.q_head_dim ** (-0.5)
715
+ if self.config.rope_scaling is not None:
716
+ # E.g. yarn-based scaling can factor in additional multipliers
717
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
718
+ scaling_factor = self.config.rope_scaling["factor"]
719
+ if mscale_all_dim:
720
+ # Simple example using the Yarn approach
721
+ self.softmax_scale *= yarn_get_mscale(scaling_factor, mscale_all_dim) ** 2
722
+
723
+ def _init_rope(self):
724
+ """
725
+ Initializes RoPE depending on scaling type: linear, dynamic, yarn, etc.
726
+ """
727
+ if self.config.rope_scaling is None:
728
+ self.rotary_emb = DeepseekV3RotaryEmbedding(
729
+ self.qk_rope_head_dim,
730
+ max_position_embeddings=self.max_position_embeddings,
731
+ base=self.rope_theta,
732
+ )
733
+ else:
734
+ scaling_type = self.config.rope_scaling["type"]
735
+ scaling_factor = self.config.rope_scaling["factor"]
736
+
737
+ if scaling_type == "linear":
738
+ self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
739
+ self.qk_rope_head_dim,
740
+ max_position_embeddings=self.max_position_embeddings,
741
+ scaling_factor=scaling_factor,
742
+ base=self.rope_theta,
743
+ )
744
+ elif scaling_type == "dynamic":
745
+ self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
746
+ self.qk_rope_head_dim,
747
+ max_position_embeddings=self.max_position_embeddings,
748
+ scaling_factor=scaling_factor,
749
+ base=self.rope_theta,
750
+ )
751
+ elif scaling_type == "yarn":
752
+ kwargs = {
753
+ key: self.config.rope_scaling[key]
754
+ for key in [
755
+ "original_max_position_embeddings",
756
+ "beta_fast",
757
+ "beta_slow",
758
+ "mscale",
759
+ "mscale_all_dim",
760
+ ]
761
+ if key in self.config.rope_scaling
762
+ }
763
+ self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
764
+ self.qk_rope_head_dim,
765
+ max_position_embeddings=self.max_position_embeddings,
766
+ scaling_factor=scaling_factor,
767
+ base=self.rope_theta,
768
+ **kwargs,
769
+ )
770
+ else:
771
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
772
+
773
+ def forward(
774
+ self,
775
+ hidden_states: torch.Tensor,
776
+ attention_mask: Optional[torch.Tensor] = None,
777
+ position_ids: Optional[torch.LongTensor] = None,
778
+ past_key_value: Optional[Cache] = None,
779
+ output_attentions: bool = False,
780
+ use_cache: bool = False,
781
+ **kwargs,
782
+ ):
783
+ """
784
+ Standard forward pass for multi-headed self-attention.
785
+ """
786
+ if "padding_mask" in kwargs:
787
+ warnings.warn(
788
+ "Passing `padding_mask` is deprecated. Use `attention_mask` instead."
789
+ )
790
+
791
+ bsz, q_len, _ = hidden_states.size()
792
+
793
+ # Q projection
794
+ if self.q_lora_rank is None:
795
+ q = self.q_proj(hidden_states)
796
+ else:
797
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
798
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
799
+ q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
800
+
801
+ # MQA: K,V from single projection
802
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
803
+ compressed_kv, k_pe = torch.split(
804
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
805
+ )
806
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
807
+ kv = (
808
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
809
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
810
+ .transpose(1, 2)
811
+ )
812
+ k_nope, value_states = torch.split(
813
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
814
+ )
815
+ kv_seq_len = value_states.shape[-2]
816
+ if past_key_value is not None:
817
+ if self.layer_idx is None:
818
+ raise ValueError(
819
+ f"Missing `layer_idx` for caching. Provide layer_idx in {self.__class__.__name__}."
820
+ )
821
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
822
+
823
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
824
+
825
+ # Apply rotary to query and key
826
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
827
+
828
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
829
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
830
+ query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
831
+
832
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
833
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
834
+ key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
835
+
836
+ if past_key_value is not None:
837
+ cache_kwargs = {"sin": sin, "cos": cos} # for RoPE
838
+ key_states, value_states = past_key_value.update(
839
+ key_states, value_states, self.layer_idx, cache_kwargs
840
+ )
841
+
842
+ attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3))
843
+ * self.softmax_scale)
844
+
845
+ if attention_mask is not None:
846
+ attn_weights = attn_weights + attention_mask
847
+
848
+ # Use float32 for more stable softmax
849
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
850
+ attn_weights = nn.functional.dropout(
851
+ attn_weights, p=self.attention_dropout, training=self.training
852
+ )
853
+ attn_output = torch.matmul(attn_weights, value_states)
854
+
855
+ attn_output = attn_output.transpose(1, 2).contiguous()
856
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
857
+
858
+ attn_output = self.o_proj(attn_output)
859
+
860
+ if not output_attentions:
861
+ attn_weights = None
862
+
863
+ return attn_output, attn_weights, past_key_value
864
+
865
+
866
+ class DeepseekV3FlashAttention2(DeepseekV3Attention):
867
+ """
868
+ DeepseekV3 flash attention module. Inherits the same Q/K/V projections from DeepseekV3Attention.
869
+ Only the forward pass changes to use flash_attn APIs.
870
+ """
871
+ def __init__(self, *args, **kwargs):
872
+ super().__init__(*args, **kwargs)
873
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
874
+
875
+ def forward(
876
+ self,
877
+ hidden_states: torch.Tensor,
878
+ attention_mask: Optional[torch.Tensor] = None,
879
+ position_ids: Optional[torch.LongTensor] = None,
880
+ past_key_value: Optional[Cache] = None,
881
+ output_attentions: bool = False,
882
+ use_cache: bool = False,
883
+ **kwargs,
884
+ ):
885
+ # Overridden forward logic using flash attention
886
+ if "padding_mask" in kwargs:
887
+ warnings.warn(
888
+ "Passing `padding_mask` is deprecated. Use `attention_mask` instead."
889
+ )
890
+ attention_mask = kwargs.pop("padding_mask")
891
+
892
+ output_attentions = False # flash attn 2 doesn't expose attention probs
893
+
894
+ bsz, q_len, _ = hidden_states.shape
895
+ if self.q_lora_rank is None:
896
+ q = self.q_proj(hidden_states)
897
+ else:
898
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
899
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
900
+ q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
901
+
902
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
903
+ compressed_kv, k_pe = torch.split(
904
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
905
+ )
906
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
907
+ kv = (
908
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
909
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
910
+ .transpose(1, 2)
911
+ )
912
+ k_nope, value_states = torch.split(
913
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
914
+ )
915
+ kv_seq_len = value_states.shape[-2]
916
+ if past_key_value is not None:
917
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
918
+
919
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
920
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
921
+
922
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
923
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
924
+ query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
925
+
926
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
927
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
928
+ key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
929
+
930
+ if self.q_head_dim != self.v_head_dim:
931
+ # Pad if needed
932
+ value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
933
+
934
+ if past_key_value is not None:
935
+ cache_kwargs = {"sin": sin, "cos": cos}
936
+ key_states, value_states = past_key_value.update(
937
+ key_states, value_states, self.layer_idx, cache_kwargs
938
+ )
939
+
940
+ # Prepare for flash-attn which needs [bsz, seqlen, n_heads, head_dim]
941
+ query_states = query_states.transpose(1, 2)
942
+ key_states = key_states.transpose(1, 2)
943
+ value_states = value_states.transpose(1, 2)
944
+
945
+ dropout_rate = self.attention_dropout if self.training else 0.0
946
+
947
+ # Possibly revert to original Q,K,V dtype if upcast to float32
948
+ input_dtype = query_states.dtype
949
+ if input_dtype == torch.float32:
950
+ # Attempt to revert to original param dtype if different
951
+ target_dtype = (
952
+ self.q_proj.weight.dtype
953
+ if self.q_lora_rank is None
954
+ else self.q_a_proj.weight.dtype
955
+ )
956
+ query_states = query_states.to(target_dtype)
957
+ key_states = key_states.to(target_dtype)
958
+ value_states = value_states.to(target_dtype)
959
+
960
+ # Flash attention pass
961
+ attn_output = self._flash_attention_forward(
962
+ query_states,
963
+ key_states,
964
+ value_states,
965
+ attention_mask,
966
+ q_len,
967
+ dropout=dropout_rate,
968
+ softmax_scale=self.softmax_scale,
969
+ )
970
+
971
+ if self.q_head_dim != self.v_head_dim:
972
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
973
+
974
+ # [bsz, seqlen, n_heads, head_dim]
975
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
976
+ attn_output = self.o_proj(attn_output)
977
+
978
+ return attn_output, None, past_key_value
979
+
980
+ def _flash_attention_forward(
981
+ self,
982
+ query_states: torch.Tensor,
983
+ key_states: torch.Tensor,
984
+ value_states: torch.Tensor,
985
+ attention_mask: Optional[torch.Tensor],
986
+ query_length: int,
987
+ dropout: float = 0.0,
988
+ softmax_scale: Optional[float] = None,
989
+ ) -> torch.Tensor:
990
+ """
991
+ Wraps the flash-attn calls. If attention_mask has padding, we unpad first.
992
+ """
993
+ if not self._flash_attn_uses_top_left_mask:
994
+ causal = self.is_causal
995
+ else:
996
+ # For flash_attn<2.1.0
997
+ causal = self.is_causal and query_length != 1
998
+
999
+ if attention_mask is not None:
1000
+ batch_size = query_states.shape[0]
1001
+ (query_states,
1002
+ key_states,
1003
+ value_states,
1004
+ indices_q,
1005
+ (cu_seqlens_q, cu_seqlens_k),
1006
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k)) = self._upad_input(
1007
+ query_states, key_states, value_states, attention_mask, query_length
1008
+ )
1009
+ attn_output_unpad = flash_attn_varlen_func(
1010
+ query_states,
1011
+ key_states,
1012
+ value_states,
1013
+ cu_seqlens_q=cu_seqlens_q,
1014
+ cu_seqlens_k=cu_seqlens_k,
1015
+ max_seqlen_q=max_seqlen_in_batch_q,
1016
+ max_seqlen_k=max_seqlen_in_batch_k,
1017
+ dropout_p=dropout,
1018
+ softmax_scale=softmax_scale,
1019
+ causal=causal,
1020
+ )
1021
+ attn_output = pad_input(
1022
+ attn_output_unpad, indices_q, batch_size, query_length
1023
+ )
1024
+ else:
1025
+ attn_output = flash_attn_func(
1026
+ query_states,
1027
+ key_states,
1028
+ value_states,
1029
+ dropout,
1030
+ softmax_scale=softmax_scale,
1031
+ causal=causal,
1032
+ )
1033
+
1034
+ return attn_output
1035
+
1036
+ def _upad_input(
1037
+ self,
1038
+ query_layer: torch.Tensor,
1039
+ key_layer: torch.Tensor,
1040
+ value_layer: torch.Tensor,
1041
+ attention_mask: torch.Tensor,
1042
+ query_length: int,
1043
+ ):
1044
+ """
1045
+ Unpads the Q, K, and V for FlashAttention in variable-length mode.
1046
+ """
1047
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1048
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1049
+
1050
+ key_layer = index_first_axis(
1051
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1052
+ indices_k,
1053
+ )
1054
+ value_layer = index_first_axis(
1055
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1056
+ indices_k,
1057
+ )
1058
+ if query_length == kv_seq_len:
1059
+ query_layer = index_first_axis(
1060
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
1061
+ indices_k,
1062
+ )
1063
+ cu_seqlens_q = cu_seqlens_k
1064
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
1065
+ indices_q = indices_k
1066
+ elif query_length == 1:
1067
+ max_seqlen_in_batch_q = 1
1068
+ cu_seqlens_q = torch.arange(
1069
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
1070
+ )
1071
+ indices_q = cu_seqlens_q[:-1]
1072
+ query_layer = query_layer.squeeze(1)
1073
+ else:
1074
+ # handle partial left padding
1075
+ attention_mask = attention_mask[:, -query_length:]
1076
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
1077
+ query_layer, attention_mask
1078
+ )
1079
+
1080
+ return (
1081
+ query_layer,
1082
+ key_layer,
1083
+ value_layer,
1084
+ indices_q,
1085
+ (cu_seqlens_q, cu_seqlens_k),
1086
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1087
+ )
1088
+
1089
+
1090
+ # Attach the attention classes in a dictionary for easy selection
1091
+ ATTENTION_CLASSES = {
1092
+ "eager": DeepseekV3Attention,
1093
+ "flash_attention_2": DeepseekV3FlashAttention2,
1094
+ }
1095
+
1096
+
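+ # Illustrative sketch: the decoder layer below selects its implementation from this mapping.
+ # `config._attn_implementation` is set by `transformers` at load time, e.g. via
+ # `from_pretrained(..., attn_implementation="flash_attention_2")`.
+ def _example_select_attention(attn_implementation: str = "eager"):
+     return ATTENTION_CLASSES[attn_implementation]  # e.g. DeepseekV3Attention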
1097
+ # ==============================================================================
1098
+ # Decoder Layer
1099
+ # ==============================================================================
1100
+
1101
+ class DeepseekV3DecoderLayer(nn.Module):
1102
+ """
1103
+ Single decoder layer composed of self-attention and MLP (optionally MoE).
1104
+ """
1105
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
1106
+ super().__init__()
1107
+ self.hidden_size = config.hidden_size
1108
+
1109
+ self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
1110
+ config=config, layer_idx=layer_idx
1111
+ )
1112
+
1113
+ # Optionally use MoE
1114
+ if (
1115
+ config.n_routed_experts is not None
1116
+ and layer_idx >= config.first_k_dense_replace
1117
+ and layer_idx % config.moe_layer_freq == 0
1118
+ ):
1119
+ self.mlp = DeepseekV3MoE(config)
1120
+ else:
1121
+ self.mlp = DeepseekV3MLP(config)
1122
+
1123
+ self.input_layernorm = DeepseekV3RMSNorm(
1124
+ config.hidden_size, eps=config.rms_norm_eps
1125
+ )
1126
+ self.post_attention_layernorm = DeepseekV3RMSNorm(
1127
+ config.hidden_size, eps=config.rms_norm_eps
1128
+ )
1129
+
1130
+ def forward(
1131
+ self,
1132
+ hidden_states: torch.Tensor,
1133
+ attention_mask: Optional[torch.Tensor] = None,
1134
+ position_ids: Optional[torch.LongTensor] = None,
1135
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1136
+ output_attentions: Optional[bool] = False,
1137
+ use_cache: Optional[bool] = False,
1138
+ **kwargs
1139
+ ):
1140
+ """
1141
+ Forward pass for one Deepseek decoder layer.
1142
+ """
1143
+ residual = hidden_states
1144
+
1145
+ # Pre-attention norm
1146
+ hidden_states = self.input_layernorm(hidden_states)
1147
+
1148
+ # Self-attention
1149
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1150
+ hidden_states=hidden_states,
1151
+ attention_mask=attention_mask,
1152
+ position_ids=position_ids,
1153
+ past_key_value=past_key_value,
1154
+ output_attentions=output_attentions,
1155
+ use_cache=use_cache,
1156
+ **kwargs,
1157
+ )
1158
+ hidden_states = residual + hidden_states
1159
+
1160
+ # Post-attention norm
1161
+ residual = hidden_states
1162
+ hidden_states = self.post_attention_layernorm(hidden_states)
1163
+
1164
+ # MLP or MoE
1165
+ hidden_states = self.mlp(hidden_states)
1166
+ hidden_states = residual + hidden_states
1167
+
1168
+ outputs = (hidden_states,)
1169
+ if output_attentions:
1170
+ outputs += (self_attn_weights,)
1171
+
1172
+ if use_cache:
1173
+ outputs += (present_key_value,)
1174
+
1175
+ return outputs
1176
+
1177
+
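+ # Illustrative sketch: which layer indices receive a MoE block under the rule used in
+ # `DeepseekV3DecoderLayer.__init__` (dense for the first `first_k_dense_replace` layers,
+ # then MoE every `moe_layer_freq` layers).
+ def _example_moe_layer_indices(num_layers: int, first_k_dense_replace: int, moe_layer_freq: int):
+     return [
+         idx for idx in range(num_layers)
+         if idx >= first_k_dense_replace and idx % moe_layer_freq == 0
+     ]
+ # _example_moe_layer_indices(8, 2, 1) -> [2, 3, 4, 5, 6, 7]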
1178
+ # ==============================================================================
1179
+ # Main Model Classes
1180
+ # ==============================================================================
1181
+
1182
+ DeepseekV3_START_DOCSTRING = r"""
1183
+ This model inherits from `PreTrainedModel`. Check the superclass documentation
1184
+ for the generic methods the library implements for all its models (such as loading, saving, etc.)
1185
+ """
1186
+
1187
+ class DeepseekV3PreTrainedModel(PreTrainedModel):
1188
+ config_class = DeepseekV3Config
1189
+ base_model_prefix = "model"
1190
+ supports_gradient_checkpointing = True
1191
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
1192
+ _skip_keys_device_placement = "past_key_values"
1193
+ _supports_flash_attn_2 = True
1194
+ _supports_cache_class = True
1195
+
1196
+ def _init_weights(self, module):
1197
+ # IMPROVEMENT: Could add more robust initialization or variants (e.g., Xavier)
1198
+ std = self.config.initializer_range
1199
+ if isinstance(module, nn.Linear):
1200
+ module.weight.data.normal_(mean=0.0, std=std)
1201
+ if module.bias is not None:
1202
+ module.bias.data.zero_()
1203
+ elif isinstance(module, nn.Embedding):
1204
+ module.weight.data.normal_(mean=0.0, std=std)
1205
+ if module.padding_idx is not None:
1206
+ module.weight.data[module.padding_idx].zero_()
1207
+
1208
+ def gradient_checkpointing_enable(self):
1209
+ self.gradient_checkpointing = True
1210
+
1211
+ def gradient_checkpointing_disable(self):
1212
+ self.gradient_checkpointing = False
1213
+
1214
+
1215
+ DeepseekV3_INPUTS_DOCSTRING = r"""
1216
+ Args:
1217
+ input_ids (torch.LongTensor): shape `(batch_size, sequence_length)`
1218
+ attention_mask (torch.Tensor): shape `(batch_size, sequence_length)` or `(batch_size, 1, seq_len, seq_len)`, optional.
1219
+ position_ids (torch.LongTensor): shape `(batch_size, sequence_length)`, optional.
1220
+ past_key_values (Cache or tuple(tuple(torch.FloatTensor))), optional:
1221
+ Pre-computed hidden-states (key and values) that can be used to speed up sequential decoding.
1222
+ inputs_embeds (torch.FloatTensor): shape `(batch_size, sequence_length, hidden_size)`, optional.
1223
+ use_cache (bool), optional
1224
+ output_attentions (bool), optional
1225
+ output_hidden_states (bool), optional
1226
+ return_dict (bool), optional
1227
+ """
1228
+
1229
+ @add_start_docstrings(
1230
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
1231
+ DeepseekV3_START_DOCSTRING,
1232
+ )
1233
+ class DeepseekV3Model(DeepseekV3PreTrainedModel):
1234
+ """
1235
+ Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a `DeepseekV3DecoderLayer`.
1236
+ """
1237
+ def __init__(self, config: DeepseekV3Config):
1238
+ super().__init__(config)
1239
+ self.padding_idx = config.pad_token_id
1240
+ self.vocab_size = config.vocab_size
1241
+
1242
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1243
+
1244
+ # Build decoder layers
1245
+ self.layers = nn.ModuleList([
1246
+ DeepseekV3DecoderLayer(config, layer_idx)
1247
+ for layer_idx in range(config.num_hidden_layers)
1248
+ ])
1249
+
1250
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1251
+ self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1252
+
1253
+ self.gradient_checkpointing = False
1254
+ self.post_init()
1255
+
1256
+ def get_input_embeddings(self) -> nn.Embedding:
1257
+ return self.embed_tokens
1258
+
1259
+ def set_input_embeddings(self, value: nn.Embedding):
1260
+ self.embed_tokens = value
1261
+
1262
+ @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1263
+ def forward(
1264
+ self,
1265
+ input_ids: Optional[torch.LongTensor] = None,
1266
+ attention_mask: Optional[torch.Tensor] = None,
1267
+ position_ids: Optional[torch.LongTensor] = None,
1268
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
1269
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1270
+ use_cache: Optional[bool] = None,
1271
+ output_attentions: Optional[bool] = None,
1272
+ output_hidden_states: Optional[bool] = None,
1273
+ return_dict: Optional[bool] = None,
1274
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1275
+
1276
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1277
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None
1278
+ else self.config.output_hidden_states)
1279
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1280
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1281
+
1282
+ if input_ids is not None and inputs_embeds is not None:
1283
+ raise ValueError("Cannot specify both input_ids and inputs_embeds at the same time.")
1284
+ elif input_ids is not None:
1285
+ batch_size, seq_length = input_ids.shape[:2]
1286
+ elif inputs_embeds is not None:
1287
+ batch_size, seq_length = inputs_embeds.shape[:2]
1288
+ else:
1289
+ raise ValueError("You must specify either input_ids or inputs_embeds.")
1290
+
1291
+ past_key_values_length = 0
1292
+ use_legacy_cache = False
1293
+ if use_cache and past_key_values is not None:
1294
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1295
+ if use_legacy_cache:
1296
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1297
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1298
+
1299
+ if position_ids is None:
1300
+ device = (input_ids.device if input_ids is not None else inputs_embeds.device)
1301
+ position_ids = torch.arange(
1302
+ past_key_values_length,
1303
+ seq_length + past_key_values_length,
1304
+ dtype=torch.long,
1305
+ device=device
1306
+ )
1307
+ position_ids = position_ids.unsqueeze(0)
1308
+
1309
+ if inputs_embeds is None:
1310
+ inputs_embeds = self.embed_tokens(input_ids)
1311
+
1312
+ # If flash attention is used, pass the 2D mask through to the layers as-is
1313
+ if self._use_flash_attention_2:
1314
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1315
+ else:
1316
+ # standard 4D mask
1317
+ attention_mask = _prepare_4d_causal_attention_mask(
1318
+ attention_mask,
1319
+ (batch_size, seq_length),
1320
+ inputs_embeds,
1321
+ past_key_values_length,
1322
+ )
1323
+
1324
+ hidden_states = inputs_embeds
1325
+
1326
+ all_hidden_states = () if output_hidden_states else None
1327
+ all_self_attns = () if output_attentions else None
1328
+ next_decoder_cache = None
1329
+
1330
+ for idx, decoder_layer in enumerate(self.layers):
1331
+ if output_hidden_states:
1332
+ all_hidden_states += (hidden_states,)
1333
+
1334
+ # Potential gradient checkpointing
1335
+ if self.gradient_checkpointing and self.training:
1336
+ def create_custom_forward(module):
1337
+ def custom_forward(*inputs):
1338
+ return module(*inputs, output_attentions=output_attentions, use_cache=use_cache)
1339
+ return custom_forward
1340
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1341
+ create_custom_forward(decoder_layer),
1342
+ hidden_states,
1343
+ attention_mask,
1344
+ position_ids,
1345
+ past_key_values
1346
+ )
1347
+ else:
1348
+ layer_outputs = decoder_layer(
1349
+ hidden_states,
1350
+ attention_mask=attention_mask,
1351
+ position_ids=position_ids,
1352
+ past_key_value=past_key_values,
1353
+ output_attentions=output_attentions,
1354
+ use_cache=use_cache,
1355
+ )
1356
+
1357
+ hidden_states = layer_outputs[0]
1358
+ if use_cache:
1359
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1360
+
1361
+ if output_attentions:
1362
+ all_self_attns += (layer_outputs[1],)
1363
+
1364
+ hidden_states = self.norm(hidden_states)
1365
+ if output_hidden_states:
1366
+ all_hidden_states += (hidden_states,)
1367
+
1368
+ # Prepare next_cache
1369
+ next_cache = None
1370
+ if use_cache:
1371
+ next_cache = (
1372
+ next_decoder_cache.to_legacy_cache()
1373
+ if use_legacy_cache
1374
+ else next_decoder_cache
1375
+ )
1376
+
1377
+ if not return_dict:
1378
+ return tuple(
1379
+ v
1380
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1381
+ if v is not None
1382
+ )
1383
+
1384
+ return BaseModelOutputWithPast(
1385
+ last_hidden_state=hidden_states,
1386
+ past_key_values=next_cache,
1387
+ hidden_states=all_hidden_states,
1388
+ attentions=all_self_attns,
1389
+ )
1390
+
1391
+
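+ # Illustrative usage: loading the bare model with the custom code in this repository.
+ # The repo id below is a placeholder; `trust_remote_code=True` is required because the
+ # architecture lives in this file rather than inside `transformers`.
+ #
+ #     from transformers import AutoConfig, AutoModel
+ #     config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)
+ #     model = AutoModel.from_config(config, trust_remote_code=True)
+ #     outputs = model(input_ids=torch.randint(0, config.vocab_size, (1, 8)))
+ #     print(outputs.last_hidden_state.shape)  # [1, 8, config.hidden_size]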
1392
+ # ==============================================================================
1393
+ # Causal LM Model
1394
+ # ==============================================================================
1395
+
1396
+ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1397
+ _tied_weights_keys = ["lm_head.weight"]
1398
+
1399
+ def __init__(self, config: DeepseekV3Config):
1400
+ super().__init__(config)
1401
+ self.model = DeepseekV3Model(config)
1402
+ self.vocab_size = config.vocab_size
1403
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1404
+
1405
+ self.post_init()
1406
+
1407
+ def get_input_embeddings(self) -> nn.Embedding:
1408
+ return self.model.embed_tokens
1409
+
1410
+ def set_input_embeddings(self, value: nn.Embedding):
1411
+ self.model.embed_tokens = value
1412
+
1413
+ def get_output_embeddings(self) -> nn.Module:
1414
+ return self.lm_head
1415
+
1416
+ def set_output_embeddings(self, new_embeddings: nn.Module):
1417
+ self.lm_head = new_embeddings
1418
+
1419
+ def set_decoder(self, decoder: nn.Module):
1420
+ self.model = decoder
1421
+
1422
+ def get_decoder(self) -> nn.Module:
1423
+ return self.model
1424
+
+    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        """
+        Args:
+            labels (torch.LongTensor of shape (batch_size, sequence_length), optional):
+                Labels for computing the language modeling loss. Indices should be in
+                [0, ..., config.vocab_size - 1]; positions set to -100 are ignored (masked) in the loss.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Decoder forward
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()  # upcast for a numerically stable loss; could stay in half precision if stable
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict token n (next-token prediction)
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
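+            # e.g. with labels [t0, t1, t2, t3] the logits at positions 0..2 are scored against
+            # targets [t1, t2, t3]; positions labelled -100 are ignored by CrossEntropyLoss.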
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
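+    # prepare_inputs_for_generation below is invoked by GenerationMixin.generate() at every decoding
+    # step, e.g. (hypothetical call): model.generate(input_ids, max_new_tokens=32, use_cache=True)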
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[Union[Cache, Tuple]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        **kwargs
+    ):
+        """
+        Prepare the inputs for one step of the generation loop: trim tokens that are already in the
+        cache, build position_ids from the attention mask, and route either input_ids or inputs_embeds.
+        """
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
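+            # cache_length = tokens currently stored in the cache; past_length = tokens already processed;
+            # max_cache_length = capacity of bounded cache types (None when the cache is unbounded).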
+
+            # Keep only the unprocessed tokens. If the attention mask is longer than input_ids,
+            # some inputs were passed exclusively via the cache (e.g. with inputs_embeds).
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
+            # If the cache has a bounded size, crop the attention mask to fit it.
+            if max_cache_length is not None and attention_mask is not None:
+                if cache_length + input_ids.shape[1] > max_cache_length:
+                    attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1]:]
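+            # cumsum over the mask yields positions that skip left padding, e.g. a mask of
+            # [0, 0, 1, 1, 1] gives position_ids [1, 1, 0, 1, 2] (the padded slots are masked out anyway).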
+
+        # If inputs_embeds are provided, use them only for the first generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values: Tuple, beam_idx: torch.Tensor) -> Tuple:
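+        """
+        Reorder the cached key/value states along the batch dimension so that each beam keeps
+        the past of the hypothesis it was selected from (used during beam search).
+        """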
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+# ==============================================================================
+# For Sequence Classification
+# ==============================================================================
+
+@add_start_docstrings(
+    """
+    DeepseekV3 with a sequence classification head on top (a linear layer over the last token's
+    hidden state). Like GPT-2 style classifiers, it pools the last non-padding token of each
+    sequence, which requires `config.pad_token_id` to be set for batch sizes greater than 1.
+    """,
+    DeepseekV3_START_DOCSTRING,
+)
+class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
+    def __init__(self, config: DeepseekV3Config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = DeepseekV3Model(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value: nn.Embedding):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        # Find the last non-padding token of each sequence; without a pad_token_id we fall back
+        # to the final position, which is only safe for batch size 1.
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no pad token is defined."
+            )
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (
+                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                ).to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[
+            torch.arange(batch_size, device=logits.device), sequence_lengths
+        ]
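+        # pooled_logits has shape (batch_size, num_labels): for each example we keep the logits at
+        # the last real token (argmax locates the first pad position, so the token before it is the
+        # last non-padding token; index -1 is used when no pad token is configured).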
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(
+                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
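+
+
+# Hypothetical usage sketch for the classification head (assumes a compatible checkpoint and that
+# this module is importable, e.g. via trust_remote_code):
+#   model = DeepseekV3ForSequenceClassification.from_pretrained("path/to/checkpoint", num_labels=2)
+#   logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits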