GuoPD committed
Commit db8a935
1 Parent(s): 65d66b1

add: add remote code

configuration_baichuan.py ADDED
@@ -0,0 +1,43 @@
+from transformers.configuration_utils import PretrainedConfig
+
+
+class BaichuanConfig(PretrainedConfig):
+    model_type = "baichuan"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=64000,
+        hidden_size=5120,
+        intermediate_size=13696,
+        num_hidden_layers=40,
+        num_attention_heads=40,
+        hidden_act="silu",
+        model_max_length=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.model_max_length = model_max_length
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
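A minimal usage sketch for the configuration class above; the smaller override values and the output directory name are illustrative only, and the import assumes the script is run from the repository root.

import_example.py (illustrative, not part of the commit):

from configuration_baichuan import BaichuanConfig

# Default construction reproduces the hyperparameters defined above.
config = BaichuanConfig()
assert config.hidden_size == 5120 and config.num_hidden_layers == 40

# Any field can be overridden, e.g. for a tiny test model (values made up).
tiny = BaichuanConfig(hidden_size=256, num_hidden_layers=2, num_attention_heads=4)
tiny.save_pretrained("./tiny-baichuan-config")  # writes config.json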
modeling_baichuan.py ADDED
@@ -0,0 +1,514 @@
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.utils import logging
+
+from .configuration_baichuan import BaichuanConfig
+
+logger = logging.get_logger(__name__)
+
+
+def _get_slopes(n):
+    def _get_slopes_power_of_2(n):
+        start = (2 ** (-2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio ** i for i in range(n)]
+
+    if math.log2(n).is_integer():
+        return _get_slopes_power_of_2(n)
+    else:
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return _get_slopes_power_of_2(closest_power_of_2) + \
+            _get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
+
+
+def _fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float("-inf")).type_as(t)
+
+
+def _gen_alibi_mask(n_head, max_pos):
+    slopes = torch.Tensor(_get_slopes(n_head))
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
+        n_head, -1, -1)
+    alibi = alibi.view(n_head, 1, max_pos)
+    alibi_mask = torch.triu(
+        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
+    )
+    alibi_mask = alibi_mask.unsqueeze(0) + alibi
+    return alibi_mask
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, hidden_size, epsilon=1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
+        self.epsilon = epsilon
+
+    def forward(self, hidden_states):
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+
+        # convert into half-precision
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
+
+
+class MLP(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ):
+        super().__init__()
+        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class BaichuanAttention(torch.nn.Module):
+
+    def __init__(self, config: BaichuanConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.model_max_length
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+        self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        proj = self.W_pack(hidden_states)
+        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attn_weights.size(-2) == 1:
+                attention_mask = attention_mask[:, -1:, :]
+            attn_weights = attn_weights + attention_mask.unsqueeze(0)
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class BaichuanLayer(torch.nn.Module):
+    def __init__(self, config: BaichuanConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = BaichuanAttention(config=config)
+        self.mlp = MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class BaichuanPreTrainedModel(PreTrainedModel):
+    config_class = BaichuanConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["BaichuanLayer"]
+    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, torch.nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, torch.nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BaichuanModel):
+            module.gradient_checkpointing = value
+
+
+class BaichuanModel(BaichuanPreTrainedModel):
+    def __init__(self, config: BaichuanConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.n_head = config.num_attention_heads
+        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
+
+        self.gradient_checkpointing = False
+        self.post_init()
+        self.max_cache_pos = config.model_max_length
+        self.first_run = True
+
+    def get_alibi_mask(self, tensor, seq_length_with_past):
+        if self.first_run:
+            self.first_run = False
+            self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
+        if seq_length_with_past > self.max_cache_pos:
+            self.max_cache_pos = seq_length_with_past
+            self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
+        mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
+        return mask
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
+
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class BaichuanForCausalLM(BaichuanPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = BaichuanModel(config)
+        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
+            for layer_past in past_key_values
+        )
+
+    def quantize(self, bits: int):
+        try:
+            from .quantizer import QLinear
+        except ImportError:
+            raise ImportError(
+                "Error: Needs QLinear to run quantize."
+            )
+
+        for layer in self.model.layers:
+            layer.self_attn.W_pack = QLinear(
+                bits=bits,
+                weight=layer.self_attn.W_pack.weight,
+                bias=None,
+            )
+            layer.self_attn.o_proj = QLinear(
+                bits=bits,
+                weight=layer.self_attn.o_proj.weight,
+                bias=None,
+            )
+            layer.mlp.gate_proj = QLinear(
+                bits=bits,
+                weight=layer.mlp.gate_proj.weight,
+                bias=None,
+            )
+            layer.mlp.down_proj = QLinear(
+                bits=bits,
+                weight=layer.mlp.down_proj.weight,
+                bias=None,
+            )
+            layer.mlp.up_proj = QLinear(
+                bits=bits,
+                weight=layer.mlp.up_proj.weight,
+                bias=None,
+            )
+        return self
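The model relies on ALiBi attention biases instead of position embeddings: _get_slopes assigns each head a geometric slope, and _gen_alibi_mask adds a per-head linear distance penalty on top of an upper-triangular -inf causal mask, which BaichuanModel.get_alibi_mask then slices to the current sequence length. A self-contained sketch of the same construction (head count and length chosen only for illustration, not taken from the commit):

import math
import torch

n_head, max_pos = 8, 16

# For a power-of-two head count the slopes are a geometric sequence,
# here 1/2, 1/4, ..., 1/256, matching _get_slopes(8).
start = 2 ** (-2 ** -(math.log2(n_head) - 3))
slopes = torch.tensor([start * start ** i for i in range(n_head)])

# Rebuild the (n_head, max_pos, max_pos) bias the way _gen_alibi_mask does:
# per-head linear distance penalty plus -inf above the diagonal for causality.
alibi = slopes.view(n_head, 1, 1) * torch.arange(max_pos).view(1, 1, max_pos)
causal = torch.triu(torch.full((max_pos, max_pos), float("-inf")), diagonal=1)
mask = causal.unsqueeze(0) + alibi

print(mask.shape)      # torch.Size([8, 16, 16])
print(mask[0, 3, :4])  # tensor([0.0000, 0.5000, 1.0000, 1.5000])
print(mask[0, 0, 1])   # tensor(-inf): future positions stay masked out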
quantizer.py ADDED
@@ -0,0 +1,24 @@
+import torch
+
+
+class QLinear(torch.nn.Module):
+    def __init__(self, bits: int, weight: torch.Tensor, bias=None):
+        super().__init__()
+        self.quant_bits = bits
+        if self.quant_bits != 8:
+            raise ValueError(
+                "Only support int8 quant in current version"
+            )
+        self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)
+        self.weight = torch.round(weight / self.scale[:, None]).to(torch.int8)
+        self.weight = self.weight.T
+        self.bias = None
+
+    def forward(self, input):
+        if self.weight.device != input.device:
+            self.weight = self.weight.to(input.device)
+            self.scale = self.scale.to(input.device)
+
+        output = torch.matmul(input, self.weight.to(input.dtype)) * self.scale.to(input.dtype)[None, None, :]
+        if self.bias is not None:
+            output = output + self.bias
+        return output
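QLinear implements symmetric, per-output-channel, weight-only int8 quantization: each weight row is scaled by its absolute maximum, rounded to int8, and dequantized on the fly during the matmul. BaichuanForCausalLM.quantize(8) swaps every attention and MLP projection for such a layer. A rough round-trip sketch on a toy weight (shapes are illustrative only):

import torch

torch.manual_seed(0)
weight = torch.randn(4, 8)  # (out_features, in_features), toy sizes

# Per-row symmetric scale and int8 storage, as in QLinear.__init__
scale = weight.abs().max(dim=-1).values / (2 ** (8 - 1) - 1)    # shape (4,)
q_weight = torch.round(weight / scale[:, None]).to(torch.int8)  # int8 weights

# Dequantize the way QLinear.forward does: matmul in float, rescale per output channel
x = torch.randn(1, 3, 8)  # (batch, seq, in_features)
approx = torch.matmul(x, q_weight.T.to(x.dtype)) * scale[None, None, :]
exact = torch.matmul(x, weight.T)

print((approx - exact).abs().max())  # small rounding error, roughly 1e-2 in magnitude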
tokenization_baichuan.py ADDED
@@ -0,0 +1,231 @@
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {},
+    "tokenizer_file": {},
+}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+
+
+class BaichuanTokenizer(PreTrainedTokenizer):
+    """
+    Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        pad_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        """Returns vocab as a dict"""
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text):
+        """Returns a tokenized string."""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for i, token in enumerate(tokens):
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special and i != 0:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        """
+        Save the vocabulary and special tokens file to a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = bos_token_id + token_ids_0 + eos_token_id
+
+        if token_ids_1 is not None:
+            output = output + bos_token_id + token_ids_1 + eos_token_id
+
+        return output
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        bos_token_id = [1] if self.add_bos_token else []
+        eos_token_id = [1] if self.add_eos_token else []
+
+        if token_ids_1 is None:
+            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+        return (
+            bos_token_id
+            + ([0] * len(token_ids_0))
+            + eos_token_id
+            + bos_token_id
+            + ([0] * len(token_ids_1))
+            + eos_token_id
+        )
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If token_ids_1 is None, only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+        if token_ids_1 is not None:
+            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+        return output
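With these four files in place, the checkpoint can be loaded through the Auto classes with trust_remote_code=True, assuming the repository's config and tokenizer_config map the Auto classes to these modules via auto_map. A hedged end-to-end sketch; the repository id, prompt, and generation settings are placeholders, and .quantize(8) is the optional weight-only int8 path from quantizer.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<org>/<baichuan-model-repo>"  # placeholder: substitute the actual repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, torch_dtype=torch.float16, trust_remote_code=True
)
model = model.quantize(8).cuda()  # optional int8 quantization; assumes a CUDA device

inputs = tokenizer("Write a short poem about the sea.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))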