schuler committed · Commit 0b1ba40 · verified · 1 Parent(s): 7916f9e

Create modeling_kphi3.py

Files changed (1):
  1. modeling_kphi3.py +1888 -0
modeling_kphi3.py ADDED
@@ -0,0 +1,1888 @@
+ # coding=utf-8
+ # This file modifies the original PHI3 Model from Microsoft. Please refer to
+ # https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
+ # for the original implementation.
+ # This implementation takes advantage of the (K) optimization described at:
+ # https://www.researchgate.net/publication/360226228_Grouped_Pointwise_Convolutions_Reduce_Parameters_in_Convolutional_Neural_Networks
+ # https://www.researchgate.net/publication/355214501_Grouped_Pointwise_Convolutions_Significantly_Reduces_Parameters_in_EfficientNet
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ PyTorch KPhi-3 model."""
+
+ import inspect
+ import math
+ import warnings
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin  # required by KPhi3PreTrainedModel below
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     is_flash_attn_2_available,
+     is_flash_attn_greater_or_equal_2_10,
+     logging,
+     replace_return_docstrings,
+ )
+ from .configuration_kphi3 import KPhi3Config
+
+ def get_max_acceptable_common_divisor(a, b, max_acceptable=1000000):
+     """
+     Returns the largest common divisor of `a` and `b` that does not exceed
+     `max_acceptable`. This is an inefficient implementation to be improved.
+     # Arguments
+         a: an integer.
+         b: an integer.
+         max_acceptable: maximum acceptable common divisor.
+     """
+     divisor = max(1, min(a, b, max_acceptable))
+     while divisor > 0:
+         if a % divisor == 0 and b % divisor == 0:
+             return divisor
+         divisor -= 1
+
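For a quick sense of the helper above, a small illustrative example (the values follow directly from the definition):

    get_max_acceptable_common_divisor(96, 64)                     # 32
    get_max_acceptable_common_divisor(96, 64, max_acceptable=16)  # 16
    get_max_acceptable_common_divisor(96, 64, max_acceptable=24)  # 16 (24 does not divide 64)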
70
+ class InterleaveChannels(nn.Module):
71
+ """
72
+ This layer interleaves channels stepping according to the number passed as parameter.
73
+ This layer assumes "channel last".
74
+ """
75
+ def __init__(self, step_size=2, last_dim=3):
76
+ super().__init__()
77
+ self.step_size = step_size if step_size >= 2 else 1
78
+ self.last_dim = last_dim
79
+
80
+ def forward(self, x):
81
+ if (self.last_dim==3):
82
+ return torch.cat([x[:, :, :, shift_pos::self.step_size] for shift_pos in range(self.step_size)], dim=3)
83
+ else:
84
+ if (self.last_dim==2):
85
+ return torch.cat([x[:, :, shift_pos::self.step_size] for shift_pos in range(self.step_size)], dim=2)
86
+
87
+ def SignedSquareRoot1(x):
88
+ """
89
+ Custom activation function that implements:
90
+ f(x) = sqrt(x) for x > 1
91
+ f(x) = sqrt(-x) for x < -1
92
+ f(x) = x for -1 ≤ x ≤ 1
93
+ """
94
+ return torch.where(x > 1,
95
+ torch.sqrt(x),
96
+ torch.where(x < -1, torch.sqrt(-x), x)
97
+ )
98
+
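A minimal numeric check of the piecewise definition above (illustrative only):

    import torch

    x = torch.tensor([-9.0, -0.25, 0.5, 4.0])
    SignedSquareRoot1(x)  # tensor([ 3.0000, -0.2500,  0.5000,  2.0000])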
99
+ # coded by GPT o1 Preview
100
+ class InterleaveChannelsFast(nn.Module):
101
+ """
102
+ This layer interleaves channels stepping according to the number passed as parameter.
103
+ This layer assumes "channel last".
104
+ """
105
+ def __init__(self, step_size=2, last_dim=3):
106
+ super().__init__()
107
+ self.step_size = max(step_size, 1)
108
+ self.last_dim = last_dim
109
+
110
+ def forward(self, x):
111
+ if self.last_dim == 3:
112
+ N, H, W, C = x.shape
113
+ if C % self.step_size != 0:
114
+ raise ValueError("Number of channels must be divisible by step_size")
115
+ # Reshape to separate the interleaving groups
116
+ x = x.view(N, H, W, self.step_size, C // self.step_size)
117
+ # Transpose to interleave the channels
118
+ x = x.permute(0, 1, 2, 4, 3)
119
+ # Flatten back to the original shape
120
+ x = x.reshape(N, H, W, C)
121
+ return x
122
+ elif self.last_dim == 2:
123
+ N, H, W = x.shape
124
+ if W % self.step_size != 0:
125
+ raise ValueError("Width must be divisible by step_size")
126
+ x = x.view(N, H, self.step_size, W // self.step_size)
127
+ x = x.permute(0, 1, 3, 2)
128
+ x = x.reshape(N, H, W)
129
+ return x
130
+ else:
131
+ raise ValueError("last_dim must be 2 or 3")
132
+
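To see what the interleaving does on the channel axis, a small sketch with `last_dim=2` and `step_size=2`: channels from the two contiguous halves end up alternating, which is what lets a following grouped projection mix information across groups.

    import torch

    layer = InterleaveChannelsFast(step_size=2, last_dim=2)
    x = torch.arange(8.0).view(1, 1, 8)  # channels [0, 1, 2, 3, 4, 5, 6, 7]
    layer(x)                             # tensor([[[0., 4., 1., 5., 2., 6., 3., 7.]]])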
133
+ class GroupedLinear(nn.Module):
134
+ """
135
+ Similarly to a grouped pointwise convolution, this layer is a grouped linear layer.
136
+ This layer assumes "channel last".
137
+ """
138
+ def __init__(self, in_features, out_features, num_groups=1, bias=True):
139
+ super().__init__()
140
+ self.in_features = in_features
141
+ self.out_features = out_features
142
+ self.num_groups = num_groups
143
+ self.bias = bias
144
+
145
+ # Check if input features are divisible by num_groups
146
+ if in_features % num_groups != 0:
147
+ raise ValueError("Input features must be divisible by num_groups.")
148
+ if out_features % num_groups != 0:
149
+ raise ValueError("Output features must be divisible by num_groups.")
150
+
151
+ self.in_features_per_group = in_features // num_groups
152
+ self.out_features_per_group = out_features // num_groups
153
+
154
+ # Create individual linear layers for each group
155
+ self.group_layers = nn.ModuleList([
156
+ nn.Linear(self.in_features_per_group, self.out_features_per_group, bias=bias)
157
+ for _ in range(num_groups)
158
+ ])
159
+
160
+ def forward(self, x):
161
+ # print('input:',x.shape,' in:',self.in_features,' out:',self.out_features,
162
+ # ' groups:',self.num_groups,
163
+ # ' in_per_group:',self.in_features_per_group,
164
+ # ' out_per_group:',self.out_features_per_group,
165
+ # ' bias:',self.bias
166
+ #)
167
+ if self.in_features != x.shape[-1]:
168
+ raise ValueError(
169
+ "GroupedLinear error: "+
170
+ "expected in_features "+str(self.in_features)+
171
+ " but got "+str(x.shape[-1])
172
+ )
173
+ # Split the input tensor into groups along the last dimension
174
+ x_groups = x.chunk(self.num_groups, dim=-1)
175
+
176
+ # for i, tensor in enumerate(x_groups):
177
+ # print(f'x_groups[{i}]: {tensor.shape}')
178
+
179
+ # Apply individual linear layers to each group
180
+ out_groups = [layer(group) for layer, group in zip(self.group_layers, x_groups)]
181
+
182
+ # Concatenate the output groups along the last dimension
183
+ out = torch.cat(out_groups, dim=-1)
184
+ if self.out_features != out.shape[-1]:
185
+ raise ValueError(
186
+ "GroupedLinear error: "+
187
+ "expected out_features "+str(self.out_features)+
188
+ " but got "+str(out.shape[-1])
189
+ )
190
+ # print('output:',out.shape)
191
+ return out
192
+
193
+ class GroupedLinearFast(nn.Module):
194
+ """
195
+ Optimized grouped linear layer.
196
+ This layer assumes "channel last".
197
+ """
198
+ def __init__(self, in_features, out_features, num_groups=1, bias=True):
199
+ super().__init__()
200
+ self.in_features = in_features
201
+ self.out_features = out_features
202
+ self.num_groups = num_groups
203
+
204
+ # Validate divisibility
205
+ if in_features % num_groups != 0:
206
+ raise ValueError("Input features must be divisible by num_groups.")
207
+ if out_features % num_groups != 0:
208
+ raise ValueError("Output features must be divisible by num_groups.")
209
+
210
+ self.in_features_per_group = in_features // num_groups
211
+ self.out_features_per_group = out_features // num_groups
212
+
213
+ # Initialize weight and bias parameters
214
+ self.weight = nn.Parameter(
215
+ torch.Tensor(num_groups, self.in_features_per_group, self.out_features_per_group)
216
+ )
217
+ if bias:
218
+ self.bias = nn.Parameter(torch.Tensor(num_groups, self.out_features_per_group))
219
+ else:
220
+ self.register_parameter('bias', None)
221
+
222
+ self.reset_parameters()
223
+
224
+ def reset_parameters(self):
225
+ # Weight initialization
226
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
227
+ if self.bias is not None:
228
+ # Bias initialization
229
+ fan_in = self.in_features_per_group
230
+ bound = 1 / math.sqrt(fan_in)
231
+ nn.init.uniform_(self.bias, -bound, bound)
232
+
233
+ def forward(self, x):
234
+ in_shape = x.shape
235
+ if x.shape[-1] != self.in_features:
236
+ raise ValueError(
237
+ f"GroupedLinear error: expected in_features {self.in_features}, but got {x.shape[-1]}"
238
+ )
239
+ # Reshape input to separate groups
240
+ x = x.view(*x.shape[:-1], self.num_groups, self.in_features_per_group)
241
+
242
+ # print('in shape', in_shape)
243
+ # print('x shape', x.shape)
244
+ # print('weight shape', self.weight.shape)
245
+
246
+ # Perform batch matrix multiplication
247
+ # x shape: [..., num_groups, in_features_per_group]
248
+ # weight shape: [num_groups, in_features_per_group, out_features_per_group]
249
+ # out = torch.matmul(x, self.weight)
250
+ out = torch.einsum('...ni,niq->...nq', x, self.weight)
251
+ # print('out shape', out.shape)
252
+ # print(in_shape[0], in_shape[1], self.out_features)
253
+
254
+ # Add bias if present
255
+ if self.bias is not None:
256
+ out += self.bias
257
+
258
+ # Reshape output back to original shape
259
+ # out = out.view(in_shape[0], in_shape[1], self.out_features)
260
+ out = out.contiguous().view(*out.shape[:-2], self.out_features)
261
+ return out
262
+
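For a rough sense of the parameter savings relative to a dense projection, a sketch using the layer above (the sizes are arbitrary):

    import torch
    from torch import nn

    dense = nn.Linear(1024, 1024, bias=False)                          # 1024 * 1024 = 1,048,576 weights
    grouped = GroupedLinearFast(1024, 1024, num_groups=8, bias=False)  # 8 * 128 * 128 = 131,072 weights
    x = torch.randn(2, 16, 1024)
    grouped(x).shape                                                   # torch.Size([2, 16, 1024])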
263
+ class GroupedPointwiseConvolutionBlock(nn.Module):
264
+ """
265
+ This layer is composed of a grouped pointwise convolution followed by channel interleaving and another grouped pointwise convolution with a skip connection. This basic architecture can
266
+ vary according to the input tensor and its parameters. This is the basic building block for the papers:
267
+ https://www.researchgate.net/publication/360226228_Grouped_Pointwise_Convolutions_Reduce_Parameters_in_Convolutional_Neural_Networks
268
+ https://www.researchgate.net/publication/355214501_Grouped_Pointwise_Convolutions_Significantly_Reduces_Parameters_in_EfficientNet
269
+ This layer assumes "channel last".
270
+ """
271
+ def __init__(self, in_features, out_features, min_channels_per_group=32, last_dim=2, use_bias=False, activation=None, has_batch_norm=False, has_batch_scale=False):
272
+ super().__init__()
273
+ self.in_features = in_features
274
+ self.out_features = out_features
275
+ self.min_channels_per_group = min_channels_per_group
276
+ self.last_dim = last_dim
277
+ self.activation = activation
278
+ self.has_batch_norm = has_batch_norm
279
+ self.has_batch_scale = has_batch_scale
280
+ self.has_interleaving = False
281
+ self.use_bias = use_bias
282
+ self.grouped = False
283
+ self.second_conv = False
284
+ self.first_pointwise_conv = None
285
+ self.second_pointwise_conv = None
286
+ self.interleave_layer = None
287
+ # this is a hack to prevent runtime errors
288
+ self.weight = torch.Tensor(1, 1, 1)
289
+ self.bias = torch.Tensor(1, 1, 1)
290
+
291
+ prev_layer_channel_count = in_features
292
+ output_channel_count = out_features
293
+ max_acceptable_divisor = (prev_layer_channel_count//min_channels_per_group)
294
+ group_count = get_max_acceptable_common_divisor(prev_layer_channel_count, output_channel_count, max_acceptable = max_acceptable_divisor)
295
+ if group_count is None: group_count=1
296
+ self.output_group_size = output_channel_count // group_count
297
+
298
+ if (group_count>1):
299
+ self.grouped = True
300
+ self.first_pointwise_conv = GroupedLinearFast(in_features=in_features, out_features=out_features, num_groups=group_count, bias=use_bias)
301
+ if self.output_group_size > 1:
302
+ self.has_interleaving = True
303
+ self.interleave_layer = InterleaveChannelsFast(self.output_group_size, last_dim=last_dim)
304
+ if (prev_layer_channel_count >= output_channel_count):
305
+ # print('Has intergroup')
306
+ self.second_conv = True
307
+ self.second_pointwise_conv = GroupedLinearFast(in_features=out_features, out_features=out_features, num_groups=group_count, bias=use_bias)
308
+ else:
309
+ #print ('Dismissed groups:', group_count, 'Input channels:', prev_layer_channel_count, 'Output Channels:', output_channel_count, 'Input channels per group:', input_group_size, 'Output channels per group:', output_group_size)
310
+ self.first_pointwise_conv = GroupedLinear(in_features=in_features, out_features=out_features, num_groups=1, bias=use_bias)
311
+
312
+ def forward(self, x):
313
+ if (self.grouped):
314
+ output_tensor = self.first_pointwise_conv(x)
315
+ if self.activation is not None:
316
+ output_tensor = self.activation(output_tensor)
317
+ compression_tensor = output_tensor
318
+ if self.has_interleaving:
319
+ output_tensor = self.interleave_layer(output_tensor)
320
+ if self.second_conv:
321
+ output_tensor = self.second_pointwise_conv(output_tensor)
322
+ if self.activation is not None:
323
+ output_tensor = self.activation(output_tensor)
324
+ output_tensor = output_tensor + compression_tensor
325
+ else:
326
+ output_tensor = self.first_pointwise_conv(x)
327
+ if self.activation is not None:
328
+ output_tensor = self.activation(output_tensor)
329
+ return output_tensor
330
+
331
+
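How the block configures itself follows directly from the constructor above; a sketch with arbitrary sizes and `min_channels_per_group=32`:

    # 1024 -> 4096: the largest common divisor not above 1024 // 32 = 32 gives group_count = 32
    # and output_group_size = 4096 // 32 = 128, so the block is one grouped projection
    # (32 groups) followed by interleaving.
    block_up = GroupedPointwiseConvolutionBlock(1024, 4096, min_channels_per_group=32)

    # 4096 -> 1024: the input is at least as wide as the output, so a second grouped
    # projection with a skip connection is added after the interleaving step.
    block_down = GroupedPointwiseConvolutionBlock(4096, 1024, min_channels_per_group=32)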
332
+ logger = logging.get_logger(__name__)
333
+
+ # Transformers scans dependencies in the modeling file, which breaks conditional loading:
+ # the scanner only ignores try/except blocks, not `if` statements, hence no
+ # `if is_flash_attn_2_available():` guard here.
+ _flash_supports_window_size = False
+ try:
+     from flash_attn import flash_attn_func, flash_attn_varlen_func
+     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+     # `_flash_attention_forward` is used by `KPhi3FlashAttention2` below; it is shipped
+     # with recent versions of transformers.
+     from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+ except ImportError as error:
+     logger.warning(
+         f"`flash-attention` package not found, consider installing for better performance: {error}."
+     )
+     if not _flash_supports_window_size:
+         logger.warning(
+             "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+         )
+
351
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
352
+ _CONFIG_FOR_DOC = "KPhi3Config"
353
+
354
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
355
+ def _prepare_4d_causal_attention_mask_with_cache_position(
356
+ attention_mask: torch.Tensor,
357
+ sequence_length: int,
358
+ target_length: int,
359
+ dtype: torch.dtype,
360
+ device: torch.device,
361
+ min_dtype: float,
362
+ cache_position: torch.Tensor,
363
+ batch_size: int,
364
+ ):
365
+ """
366
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
367
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
368
+
369
+ Args:
370
+ attention_mask (`torch.Tensor`):
371
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
372
+ sequence_length (`int`):
373
+ The sequence length being processed.
374
+ target_length (`int`):
375
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
376
+ dtype (`torch.dtype`):
377
+ The dtype to use for the 4D attention mask.
378
+ device (`torch.device`):
379
+ The device to place the 4D attention mask on.
380
+ min_dtype (`float`):
381
+ The minimum value representable with the dtype `dtype`.
382
+ cache_position (`torch.Tensor`):
383
+ Indices depicting the position of the input sequence tokens in the sequence.
384
+ batch_size (`torch.Tensor`):
385
+ Batch size.
386
+ """
387
+ if attention_mask is not None and attention_mask.dim() == 4:
388
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
389
+ causal_mask = attention_mask
390
+ else:
391
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
392
+ if sequence_length != 1:
393
+ causal_mask = torch.triu(causal_mask, diagonal=1)
394
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
395
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
396
+ if attention_mask is not None:
397
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
398
+ mask_length = attention_mask.shape[-1]
399
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
400
+ padding_mask = padding_mask == 0
401
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
402
+ padding_mask, min_dtype
403
+ )
404
+
405
+ return causal_mask
406
+
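As a small illustration of the helper above, a batch of one unpadded sequence of length 4 yields a (1, 1, 4, 4) mask whose upper triangle holds `min_dtype` and whose lower triangle is zero:

    import torch

    mask = _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=torch.ones(1, 4),
        sequence_length=4,
        target_length=4,
        dtype=torch.float32,
        device=torch.device("cpu"),
        min_dtype=torch.finfo(torch.float32).min,
        cache_position=torch.arange(4),
        batch_size=1,
    )
    mask.shape  # torch.Size([1, 1, 4, 4])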
407
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
408
+ class Phi3RMSNorm(nn.Module):
409
+ def __init__(self, hidden_size, eps=1e-6):
410
+ """
411
+ Phi3RMSNorm is equivalent to T5LayerNorm
412
+ """
413
+ super().__init__()
414
+ self.weight = nn.Parameter(torch.ones(hidden_size))
415
+ self.variance_epsilon = eps
416
+
417
+ def forward(self, hidden_states):
418
+ input_dtype = hidden_states.dtype
419
+ hidden_states = hidden_states.to(torch.float32)
420
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
421
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
422
+ return self.weight * hidden_states.to(input_dtype)
423
+
424
+ def extra_repr(self):
425
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
426
+
427
+
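A quick numeric check of the normalization above, i.e. y = x / sqrt(mean(x^2) + eps) scaled by the learned weight (ones at initialization):

    import torch

    norm = Phi3RMSNorm(hidden_size=4)
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # mean(x**2) = 7.5, rsqrt(7.5) ≈ 0.3651
    norm(x)                                   # ≈ tensor([[0.3651, 0.7303, 1.0954, 1.4606]])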
428
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
429
+ class Phi3RotaryEmbedding(nn.Module):
430
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
431
+ super().__init__()
432
+
433
+ self.dim = dim
434
+ self.max_position_embeddings = max_position_embeddings
435
+ self.base = base
436
+
437
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
438
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
439
+
440
+ @torch.no_grad()
441
+ def forward(self, x, position_ids, seq_len=None):
442
+ # x: [bs, num_attention_heads, seq_len, head_size]
443
+ self.inv_freq.to(x.device)
444
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
445
+ position_ids_expanded = position_ids[:, None, :].float()
446
+ # Force float32 since bfloat16 loses precision on long contexts
447
+ # See https://github.com/huggingface/transformers/pull/29285
448
+ device_type = x.device.type
449
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
450
+ with torch.autocast(device_type=device_type, enabled=False):
451
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
452
+ emb = torch.cat((freqs, freqs), dim=-1)
453
+ cos = emb.cos()
454
+ sin = emb.sin()
455
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
456
+
457
+
458
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
459
+ def __init__(self, dim, config, device=None):
460
+ warnings.warn(
461
+ "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
462
+ " use Phi3LongRoPEScaledRotaryEmbedding instead.",
463
+ FutureWarning,
464
+ )
465
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
466
+
467
+ self.short_factor = config.rope_scaling["short_factor"]
468
+ self.long_factor = config.rope_scaling["long_factor"]
469
+ self.original_max_position_embeddings = config.original_max_position_embeddings
470
+
471
+ @torch.no_grad()
472
+ def forward(self, x, position_ids, seq_len=None):
473
+ seq_len = torch.max(position_ids) + 1
474
+ if seq_len > self.original_max_position_embeddings:
475
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
476
+ else:
477
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
478
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
479
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
480
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
481
+ position_ids_expanded = position_ids[:, None, :].float()
482
+ # Force float32 since bfloat16 loses precision on long contexts
483
+ # See https://github.com/huggingface/transformers/pull/29285
484
+ device_type = x.device.type
485
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
486
+ with torch.autocast(device_type=device_type, enabled=False):
487
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
488
+ emb = torch.cat((freqs, freqs), dim=-1)
489
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
490
+ if scale <= 1.0:
491
+ scaling_factor = 1.0
492
+ else:
493
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
494
+ cos = emb.cos() * scaling_factor
495
+ sin = emb.sin() * scaling_factor
496
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
497
+
498
+
499
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
500
+ def __init__(self, dim, config, device=None):
501
+ warnings.warn(
502
+ "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
503
+ FutureWarning,
504
+ )
505
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
506
+
507
+ self.short_factor = config.rope_scaling["short_factor"]
508
+ self.long_factor = config.rope_scaling["long_factor"]
509
+ self.original_max_position_embeddings = config.original_max_position_embeddings
510
+
511
+ @torch.no_grad()
512
+ def forward(self, x, position_ids, seq_len=None):
513
+ seq_len = torch.max(position_ids) + 1
514
+ if seq_len > self.original_max_position_embeddings:
515
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
516
+ else:
517
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
518
+
519
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
520
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
521
+
522
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
523
+ position_ids_expanded = position_ids[:, None, :].float()
524
+
525
+ # Force float32 since bfloat16 loses precision on long contexts
526
+ # See https://github.com/huggingface/transformers/pull/29285
527
+ device_type = x.device.type
528
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
529
+ with torch.autocast(device_type=device_type, enabled=False):
530
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
531
+ emb = torch.cat((freqs, freqs), dim=-1)
532
+
533
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
534
+ if scale <= 1.0:
535
+ scaling_factor = 1.0
536
+ else:
537
+ scaling_factor = 0.1 * math.log(scale) + 1.0
538
+
539
+ cos = emb.cos() * scaling_factor
540
+ sin = emb.sin() * scaling_factor
541
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
542
+
543
+
544
+ class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
545
+ def __init__(self, dim, config, device=None):
546
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
547
+
548
+ self.short_factor = config.rope_scaling["short_factor"]
549
+ self.long_factor = config.rope_scaling["long_factor"]
550
+ self.original_max_position_embeddings = config.original_max_position_embeddings
551
+
552
+ @torch.no_grad()
553
+ def forward(self, x, position_ids, seq_len=None):
554
+ seq_len = torch.max(position_ids) + 1
555
+ if seq_len > self.original_max_position_embeddings:
556
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
557
+ else:
558
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
559
+
560
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
561
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
562
+
563
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
564
+ position_ids_expanded = position_ids[:, None, :].float()
565
+
566
+ # Force float32 since bfloat16 loses precision on long contexts
567
+ # See https://github.com/huggingface/transformers/pull/29285
568
+ device_type = x.device.type
569
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
570
+ with torch.autocast(device_type=device_type, enabled=False):
571
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
572
+ emb = torch.cat((freqs, freqs), dim=-1)
573
+
574
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
575
+ if scale <= 1.0:
576
+ scaling_factor = 1.0
577
+ else:
578
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
579
+
580
+ cos = emb.cos() * scaling_factor
581
+ sin = emb.sin() * scaling_factor
582
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
583
+
584
+
585
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
586
+ def rotate_half(x):
587
+ """Rotates half the hidden dims of the input."""
588
+ x1 = x[..., : x.shape[-1] // 2]
589
+ x2 = x[..., x.shape[-1] // 2 :]
590
+ return torch.cat((-x2, x1), dim=-1)
591
+
592
+
593
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
594
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
595
+ """Applies Rotary Position Embedding to the query and key tensors.
596
+
597
+ Args:
598
+ q (`torch.Tensor`): The query tensor.
599
+ k (`torch.Tensor`): The key tensor.
600
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
601
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
602
+ position_ids (`torch.Tensor`, *optional*):
603
+ Deprecated and unused.
604
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
605
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
606
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
607
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
608
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
609
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
610
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
611
+ Returns:
612
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
613
+ """
614
+ cos = cos.unsqueeze(unsqueeze_dim)
615
+ sin = sin.unsqueeze(unsqueeze_dim)
616
+ q_embed = (q * cos) + (rotate_half(q) * sin)
617
+ k_embed = (k * cos) + (rotate_half(k) * sin)
618
+ return q_embed, k_embed
619
+
620
+
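Shape-wise, the rotary embedding and its application work as sketched below (head_dim and sequence length are arbitrary here):

    import torch

    rope = Phi3RotaryEmbedding(dim=64)
    q = torch.randn(1, 8, 16, 64)             # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    position_ids = torch.arange(16)[None, :]  # (batch, seq_len)
    cos, sin = rope(q, position_ids)          # each (1, 16, 64)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged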
621
+ class KPhi3MLP(nn.Module):
622
+ def __init__(self, config):
623
+ super().__init__()
624
+ self.config = config
625
+ self.activation_fn = ACT2FN[config.hidden_act]
626
+ if self.config.min_channels_per_group >= 0:
627
+ self.gate_up_proj = GroupedPointwiseConvolutionBlock(in_features=config.hidden_size, out_features=(2*config.intermediate_size), min_channels_per_group=self.config.min_channels_per_group , last_dim=2, use_bias=False)
628
+ self.down_proj = GroupedPointwiseConvolutionBlock(in_features=config.intermediate_size, out_features=config.hidden_size, min_channels_per_group=self.config.min_channels_per_group, last_dim=2, use_bias=False)
629
+ else:
630
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
631
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
632
+
633
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
634
+ up_states = self.gate_up_proj(hidden_states)
635
+
636
+ gate, up_states = up_states.chunk(2, dim=-1)
637
+ up_states = up_states * self.activation_fn(gate)
638
+
639
+ return self.down_proj(up_states)
640
+
641
+
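The fused `gate_up_proj` output is simply split in two and gated, as in this stand-alone sketch (the 8192 intermediate size is only an example, roughly matching Phi-3-mini):

    import torch

    fused = torch.randn(1, 5, 2 * 8192)        # stand-in for gate_up_proj(hidden_states)
    gate, up = fused.chunk(2, dim=-1)          # each (1, 5, 8192)
    out = up * torch.nn.functional.silu(gate)  # SiLU is the usual Phi-3 hidden_act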
642
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
643
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
644
+ """
645
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
646
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
647
+ """
648
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
649
+ if n_rep == 1:
650
+ return hidden_states
651
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
652
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
653
+
654
+
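For example, with 8 key/value heads repeated 4 times the tensor ends up with 32 heads (a shape-only sketch):

    import torch

    kv = torch.randn(2, 8, 128, 96)  # (batch, num_key_value_heads, seq_len, head_dim)
    repeat_kv(kv, n_rep=4).shape     # torch.Size([2, 32, 128, 96])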
655
+ class KPhi3Attention(nn.Module):
656
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
657
+
658
+ def __init__(self, config: KPhi3Config, layer_idx: Optional[int] = None):
659
+ super().__init__()
660
+ self.config = config
661
+ self.layer_idx = layer_idx
662
+ if layer_idx is None:
663
+ logger.warning_once(
664
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
665
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
666
+ "when creating this class."
667
+ )
668
+
669
+ self.attention_dropout = config.attention_dropout
670
+ self.hidden_size = config.hidden_size
671
+ self.num_heads = config.num_attention_heads
672
+ self.head_dim = self.hidden_size // self.num_heads
673
+ self.num_key_value_heads = config.num_key_value_heads
674
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
675
+ self.max_position_embeddings = config.max_position_embeddings
676
+ self.original_max_position_embeddings = config.original_max_position_embeddings
677
+ self.rope_theta = config.rope_theta
678
+ self.rope_scaling = config.rope_scaling
679
+ self.is_causal = True
680
+
681
+ if (self.head_dim * self.num_heads) != self.hidden_size:
682
+ raise ValueError(
683
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
684
+ f" and `num_heads`: {self.num_heads})."
685
+ )
686
+
687
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
688
+ if self.config.min_channels_per_group >= 0:
689
+ self.o_proj = GroupedPointwiseConvolutionBlock(in_features=self.num_heads * self.head_dim, out_features=self.hidden_size, min_channels_per_group=self.config.min_channels_per_group, last_dim=2, use_bias=False)
690
+ self.qkv_proj = GroupedPointwiseConvolutionBlock(in_features=self.hidden_size, out_features=op_size, min_channels_per_group=self.config.min_channels_per_group, last_dim=2, use_bias=False)
691
+ else:
692
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
693
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
694
+ self._init_rope()
695
+
696
+ def _init_rope(self):
697
+ if self.rope_scaling is None:
698
+ self.rotary_emb = Phi3RotaryEmbedding(
699
+ self.head_dim,
700
+ max_position_embeddings=self.max_position_embeddings,
701
+ base=self.rope_theta,
702
+ )
703
+ else:
704
+ scaling_type = self.config.rope_scaling["type"]
705
+ if scaling_type == "longrope":
706
+ self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
707
+ else:
708
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
709
+
710
+ def forward(
711
+ self,
712
+ hidden_states: torch.Tensor,
713
+ attention_mask: Optional[torch.Tensor] = None,
714
+ position_ids: Optional[torch.LongTensor] = None,
715
+ past_key_value: Optional[Cache] = None,
716
+ output_attentions: bool = False,
717
+ use_cache: bool = False,
718
+ cache_position: Optional[torch.LongTensor] = None,
719
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
720
+ logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
721
+
722
+ bsz, q_len, _ = hidden_states.size()
723
+
724
+ qkv = self.qkv_proj(hidden_states)
725
+ query_pos = self.num_heads * self.head_dim
726
+ query_states = qkv[..., :query_pos]
727
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
728
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
729
+
730
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
731
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
732
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
733
+
734
+ kv_seq_len = key_states.shape[-2]
735
+ if past_key_value is not None:
736
+ if self.layer_idx is None:
737
+ raise ValueError(
738
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
739
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
740
+ "with a layer index."
741
+ )
742
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
743
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
744
+
745
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
746
+
747
+ if past_key_value is not None:
748
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
749
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
750
+
751
+ # repeat k/v heads if n_kv_heads < n_heads
752
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
753
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
754
+
755
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
756
+
757
+ if attention_mask is not None:
758
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
759
+ attn_weights += causal_mask
760
+
761
+ # upcast attention to fp32
762
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
763
+ # attn_weights = SignedSquareRoot1(attn_weights.to(value_states.dtype))
764
+ # value_states = SignedSquareRoot1(value_states)
765
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
766
+
767
+ attn_output = torch.matmul(attn_weights, value_states)
768
+ # attn_output = SignedSquareRoot1(attn_output)
769
+
770
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
771
+ raise ValueError(
772
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
773
+ f" {attn_output.size()}"
774
+ )
775
+
776
+ attn_output = attn_output.transpose(1, 2).contiguous()
777
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
778
+
779
+ attn_output = self.o_proj(attn_output)
780
+ # attn_output = SignedSquareRoot1(attn_output)
781
+
782
+ if not output_attentions:
783
+ attn_weights = None
784
+
785
+ return attn_output, attn_weights, past_key_value
786
+
787
+
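The slicing of the fused `qkv` tensor in the forward pass above follows this arithmetic (the sizes below are assumptions, roughly Phi-3-mini-like):

    num_heads, num_key_value_heads, head_dim = 32, 32, 96
    query_pos = num_heads * head_dim                          # 3072: end of the query slice
    op_size = query_pos + 2 * num_key_value_heads * head_dim  # 9216: width of qkv_proj's output
    # qkv[..., :query_pos] are the queries, the next num_key_value_heads * head_dim
    # channels are the keys, and the remaining channels are the values.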
788
+ class KPhi3FlashAttention2(KPhi3Attention):
789
+ """
790
+ KPhi-3 flash attention module. This module inherits from `KPhi3Attention` as the weights of the module stay
791
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
792
+ flash attention and deal with padding tokens in case the input contains any of them.
793
+ """
794
+
795
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
796
+ def __init__(self, *args, **kwargs):
797
+ super().__init__(*args, **kwargs)
798
+
799
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
800
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
801
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
802
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
803
+
804
+ def forward(
805
+ self,
806
+ hidden_states: torch.Tensor,
807
+ attention_mask: Optional[torch.LongTensor] = None,
808
+ position_ids: Optional[torch.LongTensor] = None,
809
+ past_key_value: Optional[Cache] = None,
810
+ output_attentions: bool = False,
811
+ use_cache: bool = False,
812
+ cache_position: Optional[torch.LongTensor] = None,
813
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
814
+ # Phi3FlashAttention2 attention does not support output_attentions
815
+
816
+ output_attentions = False
817
+
818
+ bsz, q_len, _ = hidden_states.size()
819
+
820
+ qkv = self.qkv_proj(hidden_states)
821
+ query_pos = self.num_heads * self.head_dim
822
+ query_states = qkv[..., :query_pos]
823
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
824
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
825
+
826
+ # Flash attention requires the input to have the shape
827
+ # batch_size x seq_length x head_dim x hidden_dim
828
+ # therefore we just need to keep the original shape
829
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
830
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
831
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
832
+
833
+ kv_seq_len = key_states.shape[-2]
834
+ if past_key_value is not None:
835
+ if self.layer_idx is None:
836
+ raise ValueError(
837
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
838
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
839
+ "with a layer index."
840
+ )
841
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
842
+
843
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
844
+ rotary_seq_len = (
845
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
846
+ )
847
+
848
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)
849
+
850
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
851
+
852
+ if past_key_value is not None:
853
+ # Activate slicing cache only if the config has a `sliding_window` attribute
854
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
855
+ if (
856
+ getattr(self.config, "sliding_window", None) is not None
857
+ and kv_seq_len > self.config.sliding_window
858
+ and cache_has_contents
859
+ ):
860
+ slicing_tokens = 1 - self.config.sliding_window
861
+
862
+ past_key = past_key_value[self.layer_idx][0]
863
+ past_value = past_key_value[self.layer_idx][1]
864
+
865
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
866
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
867
+
868
+ if past_key.shape[-2] != self.config.sliding_window - 1:
869
+ raise ValueError(
870
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
871
+ f" {past_key.shape}"
872
+ )
873
+
874
+ if attention_mask is not None:
875
+ attention_mask = attention_mask[:, slicing_tokens:]
876
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
877
+
878
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
879
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
880
+
881
+ # repeat k/v heads if n_kv_heads < n_heads
882
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
883
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
884
+
885
+ attn_dropout = self.attention_dropout if self.training else 0.0
886
+
887
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
888
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
889
+ # cast them back in the correct dtype just to be sure everything works as expected.
890
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
891
+ # in fp32.
892
+
893
+ if query_states.dtype == torch.float32:
894
+ if torch.is_autocast_enabled():
895
+ target_dtype = torch.get_autocast_gpu_dtype()
896
+ # Handle the case where the model is quantized
897
+ elif hasattr(self.config, "_pre_quantization_dtype"):
898
+ target_dtype = self.config._pre_quantization_dtype
899
+ else:
900
+ target_dtype = self.qkv_proj.weight.dtype
901
+
902
+ logger.warning_once(
903
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
904
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
905
+ f" {target_dtype}."
906
+ )
907
+
908
+ query_states = query_states.to(target_dtype)
909
+ key_states = key_states.to(target_dtype)
910
+ value_states = value_states.to(target_dtype)
911
+
912
+ # Reshape to the expected shape for Flash Attention
913
+ query_states = query_states.transpose(1, 2)
914
+ key_states = key_states.transpose(1, 2)
915
+ value_states = value_states.transpose(1, 2)
916
+
917
+ attn_output = _flash_attention_forward(
918
+ query_states,
919
+ key_states,
920
+ value_states,
921
+ attention_mask,
922
+ q_len,
923
+ position_ids=position_ids,
924
+ dropout=attn_dropout,
925
+ sliding_window=getattr(self.config, "sliding_window", None),
926
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
927
+ is_causal=self.is_causal,
928
+ )
929
+
930
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
931
+ attn_output = self.o_proj(attn_output)
932
+
933
+ if not output_attentions:
934
+ attn_weights = None
935
+
936
+ return attn_output, attn_weights, past_key_value
937
+
938
+
939
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
940
+ # TODO @Arthur no longer copied from LLama after static cache
941
+ class KPhi3SdpaAttention(KPhi3Attention):
942
+ """
943
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
944
+ `Phi3Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
945
+ SDPA API.
946
+ """
947
+
948
+ # Adapted from Phi3Attention.forward
949
+ def forward(
950
+ self,
951
+ hidden_states: torch.Tensor,
952
+ attention_mask: Optional[torch.Tensor] = None,
953
+ position_ids: Optional[torch.LongTensor] = None,
954
+ past_key_value: Optional[Cache] = None,
955
+ output_attentions: bool = False,
956
+ use_cache: bool = False,
957
+ cache_position: Optional[torch.LongTensor] = None,
958
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
959
+ if output_attentions:
960
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
961
+ logger.warning_once(
962
+ "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
963
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
964
+ )
965
+ return super().forward(
966
+ hidden_states=hidden_states,
967
+ attention_mask=attention_mask,
968
+ position_ids=position_ids,
969
+ past_key_value=past_key_value,
970
+ output_attentions=output_attentions,
971
+ use_cache=use_cache,
972
+ )
973
+
974
+ bsz, q_len, _ = hidden_states.size()
975
+
976
+ qkv = self.qkv_proj(hidden_states)
977
+ query_pos = self.num_heads * self.head_dim
978
+ query_states = qkv[..., :query_pos]
979
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
980
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
981
+
982
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
983
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
984
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
985
+
986
+ kv_seq_len = key_states.shape[-2]
987
+ if past_key_value is not None:
988
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
989
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
990
+
991
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
992
+
993
+ if past_key_value is not None:
994
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
995
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
996
+
997
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
998
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
999
+
1000
+ causal_mask = attention_mask
1001
+ if attention_mask is not None:
1002
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
1003
+
1004
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1005
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1006
+ if query_states.device.type == "cuda" and attention_mask is not None:
1007
+ query_states = query_states.contiguous()
1008
+ key_states = key_states.contiguous()
1009
+ value_states = value_states.contiguous()
1010
+
1011
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1012
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1013
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1014
+ is_causal = True if causal_mask is None and q_len > 1 else False
1015
+
1016
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1017
+ query_states,
1018
+ key_states,
1019
+ value_states,
1020
+ attn_mask=causal_mask,
1021
+ dropout_p=self.attention_dropout if self.training else 0.0,
1022
+ is_causal=is_causal,
1023
+ )
1024
+
1025
+ attn_output = attn_output.transpose(1, 2).contiguous()
1026
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1027
+
1028
+ attn_output = self.o_proj(attn_output)
1029
+
1030
+ return attn_output, None, past_key_value
1031
+
1032
+
1033
+ KPHI3_ATTENTION_CLASSES = {
1034
+ "eager": KPhi3Attention,
1035
+ "flash_attention_2": KPhi3FlashAttention2,
1036
+ "sdpa": KPhi3SdpaAttention,
1037
+ }
1038
+
1039
+
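The decoder layer below resolves its attention implementation from this mapping via `config._attn_implementation`, for example:

    attn_cls = KPHI3_ATTENTION_CLASSES["eager"]  # -> KPhi3Attention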
1040
+ class KPhi3DecoderLayer(nn.Module):
1041
+ def __init__(self, config: KPhi3Config, layer_idx: int):
1042
+ super().__init__()
1043
+
1044
+ self.config = config
1045
+ self.self_attn = KPHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
1046
+
1047
+ self.mlp = KPhi3MLP(config)
1048
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1049
+
1050
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
1051
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
1052
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1053
+
1054
+ def forward(
1055
+ self,
1056
+ hidden_states: torch.Tensor,
1057
+ attention_mask: Optional[torch.Tensor] = None,
1058
+ position_ids: Optional[torch.LongTensor] = None,
1059
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1060
+ output_attentions: Optional[bool] = False,
1061
+ use_cache: Optional[bool] = False,
1062
+ cache_position: Optional[torch.LongTensor] = None,
1063
+ **kwargs,
1064
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1065
+ """
1066
+ Args:
1067
+ hidden_states (`torch.FloatTensor`):
1068
+ input to the layer of shape `(batch, seq_len, embed_dim)`
1069
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1070
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1071
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
1072
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
1073
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
1074
+ output_attentions (`bool`, *optional*):
1075
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1076
+ returned tensors for more detail.
1077
+ use_cache (`bool`, *optional*):
1078
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1079
+ (see `past_key_values`).
1080
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1081
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1082
+ Indices depicting the position of the input sequence tokens in the sequence
1083
+ kwargs (`dict`, *optional*):
1084
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
1085
+ into the model
1086
+ """
1087
+
1088
+ residual = hidden_states
1089
+
1090
+ hidden_states = self.input_layernorm(hidden_states)
1091
+
1092
+ # Self Attention
1093
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
1094
+ hidden_states=hidden_states,
1095
+ attention_mask=attention_mask,
1096
+ position_ids=position_ids,
1097
+ past_key_value=past_key_value,
1098
+ output_attentions=output_attentions,
1099
+ use_cache=use_cache,
1100
+ cache_position=cache_position,
1101
+ )
1102
+
1103
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
1104
+
1105
+ residual = hidden_states
1106
+ hidden_states = self.post_attention_layernorm(hidden_states)
1107
+ hidden_states = self.mlp(hidden_states)
1108
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
1109
+
1110
+ outputs = (hidden_states,)
1111
+
1112
+ if output_attentions:
1113
+ outputs += (self_attn_weights,)
1114
+
1115
+ if use_cache:
1116
+ outputs += (present_key_value,)
1117
+
1118
+ return outputs
1119
+
1120
+
1121
+ KPHI3_START_DOCSTRING = r"""
1122
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1123
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1124
+ etc.)
1125
+
1126
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1127
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1128
+ and behavior.
1129
+
1130
+ Parameters:
1131
+ config ([`KPhi3Config`]):
1132
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1133
+ load the weights associated with the model, only the configuration. Check out the
1134
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1135
+ """
1136
+
1137
+
1138
+ @add_start_docstrings(
1139
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
1140
+ KPHI3_START_DOCSTRING,
1141
+ )
1142
+ class KPhi3PreTrainedModel(PreTrainedModel, GenerationMixin):
1143
+ config_class = KPhi3Config
1144
+ base_model_prefix = "model"
1145
+ supports_gradient_checkpointing = True
1146
+ _no_split_modules = ["KPhi3DecoderLayer"]
1147
+ _skip_keys_device_placement = "past_key_values"
1148
+ _supports_flash_attn_2 = True
1149
+ _supports_sdpa = False
1150
+ _supports_cache_class = True
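+ # Note: with `_supports_sdpa = False`, attention runs through the eager implementation unless
+ # `flash_attention_2` is explicitly selected via `config._attn_implementation`.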
1151
+
1152
+ _version = "1.0.0"
1153
+
1154
+ def _init_weights(self, module):
1155
+ std = self.config.initializer_range
1156
+ if isinstance(module, nn.Linear):
1157
+ module.weight.data.normal_(mean=0.0, std=std)
1158
+ if module.bias is not None:
1159
+ module.bias.data.zero_()
1160
+ elif isinstance(module, nn.Embedding):
1161
+ module.weight.data.normal_(mean=0.0, std=std)
1162
+ if module.padding_idx is not None:
1163
+ module.weight.data[module.padding_idx].zero_()
1164
+
1165
+
1166
+ KPHI3_INPUTS_DOCSTRING = r"""
1167
+ Args:
1168
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1169
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1170
+ it.
1171
+
1172
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1173
+ [`PreTrainedTokenizer.__call__`] for details.
1174
+
1175
+ [What are input IDs?](../glossary#input-ids)
1176
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1177
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1178
+
1179
+ - 1 for tokens that are **not masked**,
1180
+ - 0 for tokens that are **masked**.
1181
+
1182
+ [What are attention masks?](../glossary#attention-mask)
1183
+
1184
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1185
+ [`PreTrainedTokenizer.__call__`] for details.
1186
+
1187
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1188
+ `past_key_values`).
1189
+
1190
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1191
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1192
+ information on the default strategy.
1193
+
1194
+ - 1 indicates the head is **not masked**,
1195
+ - 0 indicates the head is **masked**.
1196
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1197
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1198
+ config.n_positions - 1]`.
1199
+
1200
+ [What are position IDs?](../glossary#position-ids)
1201
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1202
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1203
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1204
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1205
+
1206
+ Two formats are allowed:
1207
+ - a [`~cache_utils.Cache`] instance;
1208
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1209
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1210
+ cache format.
1211
+
1212
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1213
+ legacy cache format will be returned.
1214
+
1215
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1216
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1217
+ of shape `(batch_size, sequence_length)`.
1218
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1219
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1220
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1221
+ model's internal embedding lookup matrix.
1222
+ use_cache (`bool`, *optional*):
1223
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1224
+ `past_key_values`).
1225
+ output_attentions (`bool`, *optional*):
1226
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1227
+ tensors for more detail.
1228
+ output_hidden_states (`bool`, *optional*):
1229
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1230
+ more detail.
1231
+ return_dict (`bool`, *optional*):
1232
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1233
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1234
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1235
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1236
+ the complete sequence length.
1237
+ """
1238
+
1239
+
1240
+ @add_start_docstrings(
1241
+ "The bare KPhi-3 model outputting raw hidden-states without any specific head on top.",
1242
+ KPHI3_START_DOCSTRING,
1243
+ )
1244
+ class KPhi3Model(KPhi3PreTrainedModel):
1245
+ """
1246
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`KPhi3DecoderLayer`]
1247
+
1248
+ Args:
1249
+ config: KPhi3Config
1250
+ """
1251
+
1252
+ def __init__(self, config: KPhi3Config):
1253
+ super().__init__(config)
1254
+ self.padding_idx = config.pad_token_id
1255
+ self.vocab_size = config.vocab_size
1256
+ self.activation_fn = ACT2FN[config.hidden_act]
1257
+
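+ # KPhi3 factorizes the embedding: tokens are embedded at `embed_size` and, when it differs from
+ # `hidden_size`, a grouped pointwise convolution block (defined earlier in this module) projects
+ # the embeddings up to the decoder width.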
1258
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_size, self.padding_idx)
1259
+ if config.embed_size != config.hidden_size:
1260
+ self.embed_to_hidden = GroupedPointwiseConvolutionBlock(config.embed_size, config.hidden_size, config.min_channels_per_group)
1261
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1262
+ self.layers = nn.ModuleList(
1263
+ [KPhi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1264
+ )
1265
+ self._attn_implementation = config._attn_implementation
1266
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1267
+
1268
+ self.gradient_checkpointing = False
1269
+ # Initialize weights and apply final processing
1270
+ self.post_init()
1271
+
1272
+ def get_input_embeddings(self):
1273
+ return self.embed_tokens
1274
+
1275
+ def set_input_embeddings(self, value):
1276
+ self.embed_tokens = value
1277
+
1278
+ @add_start_docstrings_to_model_forward(KPHI3_INPUTS_DOCSTRING)
1279
+ def forward(
1280
+ self,
1281
+ input_ids: torch.LongTensor = None,
1282
+ attention_mask: Optional[torch.Tensor] = None,
1283
+ position_ids: Optional[torch.LongTensor] = None,
1284
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1285
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1286
+ use_cache: Optional[bool] = None,
1287
+ output_attentions: Optional[bool] = None,
1288
+ output_hidden_states: Optional[bool] = None,
1289
+ return_dict: Optional[bool] = None,
1290
+ cache_position: Optional[torch.LongTensor] = None,
1291
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1292
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1293
+ output_hidden_states = (
1294
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1295
+ )
1296
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1297
+
1298
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1299
+
1300
+ if (input_ids is None) ^ (inputs_embeds is not None):
1301
+ raise ValueError(
1302
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1303
+ )
1304
+
1305
+ if self.gradient_checkpointing and self.training:
1306
+ if use_cache:
1307
+ logger.warning_once(
1308
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1309
+ )
1310
+ use_cache = False
1311
+
1312
+ use_legacy_cache = False
1313
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1314
+ use_legacy_cache = True
1315
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1316
+ logger.warning_once(
1317
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1318
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
1319
+ )
1320
+
1321
+ if inputs_embeds is None:
1322
+ inputs_embeds = self.embed_tokens(input_ids)
1323
+
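+ # `cache_position` indexes absolute positions in the KV cache and is unaffected by padding;
+ # when the caller provides neither, `position_ids` simply defaults to the same values.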
1324
+ if cache_position is None:
1325
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1326
+ cache_position = torch.arange(
1327
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1328
+ )
1329
+ if position_ids is None:
1330
+ position_ids = cache_position.unsqueeze(0)
1331
+
1332
+ causal_mask = self._update_causal_mask(
1333
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1334
+ )
1335
+ inputs_embeds = self.embed_dropout(inputs_embeds)
1336
+
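+ # Project the (smaller) embedding width up to the decoder's hidden width before the layer stack;
+ # `KPhi3ForCausalLM` applies the mirror projection back down before the LM head.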
1337
+ if self.config.embed_size != self.config.hidden_size:
1338
+ hidden_states = self.embed_to_hidden(inputs_embeds)
1339
+ # hidden_states = self.activation_fn(self.embed_to_hidden(inputs_embeds))
1340
+ else:
1341
+ hidden_states = inputs_embeds
1342
+
1343
+ # decoder layers
1344
+ all_hidden_states = () if output_hidden_states else None
1345
+ all_self_attns = () if output_attentions else None
1346
+ next_decoder_cache = None
1347
+
1348
+ for decoder_layer in self.layers:
1349
+ if output_hidden_states:
1350
+ all_hidden_states += (hidden_states,)
1351
+
1352
+ if self.gradient_checkpointing and self.training:
1353
+ layer_outputs = self._gradient_checkpointing_func(
1354
+ decoder_layer.__call__,
1355
+ hidden_states,
1356
+ causal_mask,
1357
+ position_ids,
1358
+ past_key_values,
1359
+ output_attentions,
1360
+ use_cache,
1361
+ cache_position,
1362
+ )
1363
+ else:
1364
+ layer_outputs = decoder_layer(
1365
+ hidden_states,
1366
+ attention_mask=causal_mask,
1367
+ position_ids=position_ids,
1368
+ past_key_value=past_key_values,
1369
+ output_attentions=output_attentions,
1370
+ use_cache=use_cache,
1371
+ cache_position=cache_position,
1372
+ )
1373
+
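+ # `layer_outputs` is (hidden_states, [attn_weights if output_attentions], [cache if use_cache]),
+ # matching the tuple built at the end of `KPhi3DecoderLayer.forward`.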
1374
+ hidden_states = layer_outputs[0]
1375
+
1376
+ if use_cache:
1377
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1378
+
1379
+ if output_attentions:
1380
+ all_self_attns += (layer_outputs[1],)
1381
+
1382
+ hidden_states = self.norm(hidden_states)
1383
+
1384
+ # add hidden states from the last decoder layer
1385
+ if output_hidden_states:
1386
+ all_hidden_states += (hidden_states,)
1387
+
1388
+ next_cache = None
1389
+ if use_cache:
1390
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1391
+ if not return_dict:
1392
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1393
+ return BaseModelOutputWithPast(
1394
+ last_hidden_state=hidden_states,
1395
+ past_key_values=next_cache,
1396
+ hidden_states=all_hidden_states,
1397
+ attentions=all_self_attns,
1398
+ )
1399
+
1400
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1401
+ def _update_causal_mask(
1402
+ self,
1403
+ attention_mask: torch.Tensor,
1404
+ input_tensor: torch.Tensor,
1405
+ cache_position: torch.Tensor,
1406
+ past_key_values: Cache,
1407
+ output_attentions: bool,
1408
+ ):
1409
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1410
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1411
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1412
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1413
+
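+ # Flash Attention 2 consumes the 2D padding mask directly, and only needs it when some
+ # positions are actually padded; otherwise no mask is materialized.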
1414
+ if self.config._attn_implementation == "flash_attention_2":
1415
+ if attention_mask is not None and 0.0 in attention_mask:
1416
+ return attention_mask
1417
+ return None
1418
+
1419
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1420
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1421
+ # to infer the attention mask.
1422
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1423
+ using_static_cache = isinstance(past_key_values, StaticCache)
1424
+
1425
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1426
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1427
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1428
+ attention_mask,
1429
+ inputs_embeds=input_tensor,
1430
+ past_key_values_length=past_seen_tokens,
1431
+ is_training=self.training,
1432
+ ):
1433
+ return None
1434
+
1435
+ dtype, device = input_tensor.dtype, input_tensor.device
1436
+ min_dtype = torch.finfo(dtype).min
1437
+ sequence_length = input_tensor.shape[1]
1438
+ if using_static_cache:
1439
+ target_length = past_key_values.get_max_length()
1440
+ else:
1441
+ target_length = (
1442
+ attention_mask.shape[-1]
1443
+ if isinstance(attention_mask, torch.Tensor)
1444
+ else past_seen_tokens + sequence_length + 1
1445
+ )
1446
+
1447
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1448
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1449
+ attention_mask,
1450
+ sequence_length=sequence_length,
1451
+ target_length=target_length,
1452
+ dtype=dtype,
1453
+ device=device,
1454
+ min_dtype=min_dtype,
1455
+ cache_position=cache_position,
1456
+ batch_size=input_tensor.shape[0],
1457
+ )
1458
+
1459
+ if (
1460
+ self.config._attn_implementation == "sdpa"
1461
+ and attention_mask is not None
1462
+ and attention_mask.device.type == "cuda"
1463
+ and not output_attentions
1464
+ ):
1465
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1466
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1467
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1468
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1469
+
1470
+ return causal_mask
1471
+
1472
+
1473
+ class KPhi3ForCausalLM(KPhi3PreTrainedModel):
1474
+ _tied_weights_keys = ["lm_head.weight"]
1475
+
1476
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1477
+ def __init__(self, config):
1478
+ super().__init__(config)
1479
+ self.model = KPhi3Model(config)
1480
+ self.vocab_size = config.vocab_size
1481
+ if config.embed_size != config.hidden_size:
1482
+ self.hidden_to_embed = GroupedPointwiseConvolutionBlock(config.hidden_size, config.embed_size, config.min_channels_per_group)
1483
+ self.lm_head = nn.Linear(config.embed_size, config.vocab_size, bias=False)
1484
+
1485
+ # Initialize weights and apply final processing
1486
+ self.post_init()
1487
+
1488
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1489
+ def get_input_embeddings(self):
1490
+ return self.model.embed_tokens
1491
+
1492
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1493
+ def set_input_embeddings(self, value):
1494
+ self.model.embed_tokens = value
1495
+
1496
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1497
+ def get_output_embeddings(self):
1498
+ return self.lm_head
1499
+
1500
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1501
+ def set_output_embeddings(self, new_embeddings):
1502
+ self.lm_head = new_embeddings
1503
+
1504
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1505
+ def set_decoder(self, decoder):
1506
+ self.model = decoder
1507
+
1508
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1509
+ def get_decoder(self):
1510
+ return self.model
1511
+
1512
+ # Ignore copy
1513
+ @add_start_docstrings_to_model_forward(KPHI3_INPUTS_DOCSTRING)
1514
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1515
+ def forward(
1516
+ self,
1517
+ input_ids: torch.LongTensor = None,
1518
+ attention_mask: Optional[torch.Tensor] = None,
1519
+ position_ids: Optional[torch.LongTensor] = None,
1520
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1521
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1522
+ labels: Optional[torch.LongTensor] = None,
1523
+ use_cache: Optional[bool] = None,
1524
+ output_attentions: Optional[bool] = None,
1525
+ output_hidden_states: Optional[bool] = None,
1526
+ return_dict: Optional[bool] = None,
1527
+ cache_position: Optional[torch.LongTensor] = None,
1528
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1529
+ r"""
1530
+ Args:
1531
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1532
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1533
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1534
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1535
+
1536
+ Returns:
1537
+
1538
+ Example:
1539
+
1540
+ ```python
1541
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
1542
+
1543
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct", trust_remote_code=True)
1544
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1545
+
1546
+ >>> prompt = "This is an example script ."
1547
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1548
+
1549
+ >>> # Generate
1550
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1551
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1552
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1553
+ ```"""
1554
+
1555
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1556
+ output_hidden_states = (
1557
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1558
+ )
1559
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1560
+
1561
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1562
+ outputs = self.model(
1563
+ input_ids=input_ids,
1564
+ attention_mask=attention_mask,
1565
+ position_ids=position_ids,
1566
+ past_key_values=past_key_values,
1567
+ inputs_embeds=inputs_embeds,
1568
+ use_cache=use_cache,
1569
+ output_attentions=output_attentions,
1570
+ output_hidden_states=output_hidden_states,
1571
+ return_dict=return_dict,
+ cache_position=cache_position,
1572
+ )
1573
+
1574
+ hidden_states = outputs[0]
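+ # Mirror of the input projection: map `hidden_size` back down to `embed_size` so that
+ # `lm_head` (embed_size -> vocab_size) can share its weight with `embed_tokens`
+ # (see `_tied_weights_keys`) when weight tying is enabled.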
1575
+ if self.config.embed_size != self.config.hidden_size:
1576
+ hidden_states = self.hidden_to_embed(hidden_states)
1577
+ logits = self.lm_head(hidden_states)
1578
+ logits = logits.float()
1579
+
1580
+ loss = None
1581
+ if labels is not None:
1582
+ # Shift so that tokens < n predict n
1583
+ shift_logits = logits[..., :-1, :].contiguous()
1584
+ shift_labels = labels[..., 1:].contiguous()
1585
+ # Flatten the tokens
1586
+ loss_fct = CrossEntropyLoss()
1587
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1588
+ shift_labels = shift_labels.view(-1)
1589
+ # Enable model parallelism
1590
+ shift_labels = shift_labels.to(shift_logits.device)
1591
+ loss = loss_fct(shift_logits, shift_labels)
1592
+
1593
+ if not return_dict:
1594
+ output = (logits,) + outputs[1:]
1595
+ return (loss,) + output if loss is not None else output
1596
+
1597
+ return CausalLMOutputWithPast(
1598
+ loss=loss,
1599
+ logits=logits,
1600
+ past_key_values=outputs.past_key_values,
1601
+ hidden_states=outputs.hidden_states,
1602
+ attentions=outputs.attentions,
1603
+ )
1604
+
1605
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1606
+ def prepare_inputs_for_generation(
1607
+ self,
1608
+ input_ids,
1609
+ past_key_values=None,
1610
+ attention_mask=None,
1611
+ inputs_embeds=None,
1612
+ cache_position=None,
1613
+ position_ids=None,
1614
+ use_cache=True,
1615
+ **kwargs,
1616
+ ):
1617
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1618
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1619
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1620
+ if past_key_values is not None:
1621
+ if inputs_embeds is not None: # Exception 1
1622
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1623
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1624
+ input_ids = input_ids[:, cache_position]
1625
+
1626
+ if attention_mask is not None and position_ids is None:
1627
+ # create position_ids on the fly for batch generation
1628
+ position_ids = attention_mask.long().cumsum(-1) - 1
1629
+ position_ids.masked_fill_(attention_mask == 0, 1)
1630
+ if past_key_values:
1631
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1632
+
1633
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
+ # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride
+ # during decoding. Simply using `.contiguous()` is not sufficient: in the batch size = 1 case,
+ # `position_ids` is already contiguous but with varying stride, which retriggers a capture.
1634
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1635
+
1636
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1637
+ if inputs_embeds is not None and cache_position[0] == 0:
1638
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1639
+ else:
1640
+ # The clone here is for the same reason as for `position_ids`.
1641
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1642
+
1643
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1644
+ if model_inputs["inputs_embeds"] is not None:
1645
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1646
+ device = model_inputs["inputs_embeds"].device
1647
+ else:
1648
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1649
+ device = model_inputs["input_ids"].device
1650
+
1651
+ dtype = self.lm_head.weight.dtype
1652
+ min_dtype = torch.finfo(dtype).min
1653
+
1654
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1655
+ attention_mask,
1656
+ sequence_length=sequence_length,
1657
+ target_length=past_key_values.get_max_length(),
1658
+ dtype=dtype,
1659
+ device=device,
1660
+ min_dtype=min_dtype,
1661
+ cache_position=cache_position,
1662
+ batch_size=batch_size,
1663
+ )
1664
+
1665
+ model_inputs.update(
1666
+ {
1667
+ "position_ids": position_ids,
1668
+ "cache_position": cache_position,
1669
+ "past_key_values": past_key_values,
1670
+ "use_cache": use_cache,
1671
+ "attention_mask": attention_mask,
1672
+ }
1673
+ )
1674
+ return model_inputs
1675
+
1676
+
1677
+ @add_start_docstrings(
1678
+ """
1679
+ The [`KPhi3Model`] with a sequence classification head on top (linear layer).
1680
+
1681
+ [`KPhi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1682
+ (e.g. GPT-2) do.
1683
+
1684
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1685
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1686
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1687
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1688
+ each row of the batch).
1689
+ """,
1690
+ KPHI3_START_DOCSTRING,
1691
+ )
1692
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1693
+ class KPhi3ForSequenceClassification(KPhi3PreTrainedModel):
1694
+ def __init__(self, config):
1695
+ super().__init__(config)
1696
+ self.num_labels = config.num_labels
1697
+ self.model = KPhi3Model(config)
1698
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1699
+
1700
+ # Initialize weights and apply final processing
1701
+ self.post_init()
1702
+
1703
+ def get_input_embeddings(self):
1704
+ return self.model.embed_tokens
1705
+
1706
+ def set_input_embeddings(self, value):
1707
+ self.model.embed_tokens = value
1708
+
1709
+ @add_start_docstrings_to_model_forward(KPHI3_INPUTS_DOCSTRING)
1710
+ def forward(
1711
+ self,
1712
+ input_ids: Optional[torch.LongTensor] = None,
1713
+ attention_mask: Optional[torch.Tensor] = None,
1714
+ position_ids: Optional[torch.LongTensor] = None,
1715
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1716
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1717
+ labels: Optional[torch.LongTensor] = None,
1718
+ use_cache: Optional[bool] = None,
1719
+ output_attentions: Optional[bool] = None,
1720
+ output_hidden_states: Optional[bool] = None,
1721
+ return_dict: Optional[bool] = None,
1722
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1723
+ r"""
1724
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1725
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1726
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1727
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1728
+ """
1729
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1730
+
1731
+ model_outputs = self.model(
1732
+ input_ids,
1733
+ attention_mask=attention_mask,
1734
+ position_ids=position_ids,
1735
+ past_key_values=past_key_values,
1736
+ inputs_embeds=inputs_embeds,
1737
+ use_cache=use_cache,
1738
+ output_attentions=output_attentions,
1739
+ output_hidden_states=output_hidden_states,
1740
+ return_dict=return_dict,
1741
+ )
1742
+ hidden_states = model_outputs[0]
1743
+ logits = self.score(hidden_states)
1744
+
1745
+ if input_ids is not None:
1746
+ batch_size = input_ids.shape[0]
1747
+ else:
1748
+ batch_size = inputs_embeds.shape[0]
1749
+
1750
+ if self.config.pad_token_id is None and batch_size != 1:
1751
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1752
+ if self.config.pad_token_id is None:
1753
+ sequence_lengths = -1
1754
+ else:
1755
+ if input_ids is not None:
1756
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1757
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1758
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1759
+ sequence_lengths = sequence_lengths.to(logits.device)
1760
+ else:
1761
+ sequence_lengths = -1
1762
+
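+ # Pool by taking the logits at the last non-padding token of each sequence (or simply the last
+ # position when no `pad_token_id` is configured or `inputs_embeds` are used).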
1763
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1764
+
1765
+ loss = None
1766
+ if labels is not None:
1767
+ labels = labels.to(logits.device)
1768
+ if self.config.problem_type is None:
1769
+ if self.num_labels == 1:
1770
+ self.config.problem_type = "regression"
1771
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1772
+ self.config.problem_type = "single_label_classification"
1773
+ else:
1774
+ self.config.problem_type = "multi_label_classification"
1775
+
1776
+ if self.config.problem_type == "regression":
1777
+ loss_fct = MSELoss()
1778
+ if self.num_labels == 1:
1779
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1780
+ else:
1781
+ loss = loss_fct(pooled_logits, labels)
1782
+ elif self.config.problem_type == "single_label_classification":
1783
+ loss_fct = CrossEntropyLoss()
1784
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1785
+ elif self.config.problem_type == "multi_label_classification":
1786
+ loss_fct = BCEWithLogitsLoss()
1787
+ loss = loss_fct(pooled_logits, labels)
1788
+ if not return_dict:
1789
+ output = (pooled_logits,) + model_outputs[1:]
1790
+ return ((loss,) + output) if loss is not None else output
1791
+
1792
+ return SequenceClassifierOutputWithPast(
1793
+ loss=loss,
1794
+ logits=pooled_logits,
1795
+ past_key_values=model_outputs.past_key_values,
1796
+ hidden_states=model_outputs.hidden_states,
1797
+ attentions=model_outputs.attentions,
1798
+ )
1799
+
1800
+
1801
+ @add_start_docstrings(
1802
+ """
1803
+ [`KPhi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1804
+ Named-Entity-Recognition (NER) tasks.
1805
+ """,
1806
+ KPHI3_START_DOCSTRING,
1807
+ )
1808
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1809
+ class KPhi3ForTokenClassification(KPhi3PreTrainedModel):
1810
+ def __init__(self, config: KPhi3Config):
1811
+ super().__init__(config)
1812
+ self.num_labels = config.num_labels
1813
+
1814
+ self.model = KPhi3Model(config)
1815
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1816
+ classifier_dropout = config.classifier_dropout
1817
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1818
+ classifier_dropout = config.hidden_dropout
1819
+ else:
1820
+ classifier_dropout = 0.1
1821
+ self.dropout = nn.Dropout(classifier_dropout)
1822
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1823
+
1824
+ # Initialize weights and apply final processing
1825
+ self.post_init()
1826
+
1827
+ @add_start_docstrings_to_model_forward(KPHI3_INPUTS_DOCSTRING)
1828
+ @add_code_sample_docstrings(
1829
+ checkpoint=_CHECKPOINT_FOR_DOC,
1830
+ output_type=TokenClassifierOutput,
1831
+ config_class=_CONFIG_FOR_DOC,
1832
+ )
1833
+ def forward(
1834
+ self,
1835
+ input_ids: Optional[torch.LongTensor] = None,
1836
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1837
+ attention_mask: Optional[torch.Tensor] = None,
1838
+ inputs_embeds: Optional[torch.Tensor] = None,
1839
+ labels: Optional[torch.Tensor] = None,
1840
+ use_cache: Optional[bool] = None,
1841
+ output_attentions: Optional[bool] = None,
1842
+ output_hidden_states: Optional[bool] = None,
1843
+ return_dict: Optional[bool] = None,
1844
+ **deprecated_arguments,
1845
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1846
+ r"""
1847
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1848
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1849
+ config.num_labels - 1]`. Tokens with labels set to `-100` are ignored (masked); the loss is
1850
+ only computed for tokens with labels in `[0, ..., config.num_labels - 1]`.
1851
+ """
1852
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1853
+
1854
+ model_outputs = self.model(
1855
+ input_ids,
1856
+ past_key_values=past_key_values,
1857
+ attention_mask=attention_mask,
1858
+ inputs_embeds=inputs_embeds,
1859
+ use_cache=use_cache,
1860
+ output_attentions=output_attentions,
1861
+ output_hidden_states=output_hidden_states,
1862
+ return_dict=return_dict,
1863
+ )
1864
+
1865
+ hidden_states = model_outputs[0]
1866
+ hidden_states = self.dropout(hidden_states)
1867
+ logits = self.classifier(hidden_states)
1868
+
1869
+ loss = None
1870
+ if labels is not None:
1871
+ # move labels to correct device to enable model parallelism
1872
+ labels = labels.to(logits.device)
1873
+ batch_size, seq_length = labels.shape
1874
+ loss_fct = CrossEntropyLoss()
1875
+ loss = loss_fct(
1876
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1877
+ )
1878
+
1879
+ if not return_dict:
1880
+ output = (logits,) + model_outputs[2:]
1881
+ return ((loss,) + output) if loss is not None else output
1882
+
1883
+ return TokenClassifierOutput(
1884
+ loss=loss,
1885
+ logits=logits,
1886
+ hidden_states=model_outputs.hidden_states,
1887
+ attentions=model_outputs.attentions,
1888
+ )