lhallee committed
Commit 9117f48 · verified · 1 Parent(s): fbb9df7

Update modeling_esm_plusplus.py

Files changed (1)
  1. modeling_esm_plusplus.py +636 -635
modeling_esm_plusplus.py CHANGED
@@ -1,635 +1,636 @@
1
- ### Modified from https://github.com/evolutionaryscale/esm
2
- ### License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- import math
7
- from dataclasses import dataclass
8
- from transformers import PreTrainedModel, PretrainedConfig
9
- from einops import rearrange, repeat
10
- from functools import partial
11
- from typing import Optional, Tuple
12
- from transformers.modeling_outputs import ModelOutput
13
-
14
-
15
- class ESMplusplusConfig(PretrainedConfig):
16
- model_type = "ESMplusplus"
17
- def __init__(
18
- self,
19
- vocab_size: int = 64,
20
- hidden_size: int = 960,
21
- num_attention_heads: int = 15,
22
- num_hidden_layers: int = 30,
23
- num_labels: int = 2,
24
- problem_type: str | None = None,
25
- **kwargs,
26
- ):
27
- super().__init__(**kwargs)
28
- self.vocab_size = vocab_size
29
- self.hidden_size = hidden_size
30
- self.num_attention_heads = num_attention_heads
31
- self.num_hidden_layers = num_hidden_layers
32
- self.num_labels = num_labels
33
- self.problem_type = problem_type
34
-
35
-
36
- ### Rotary
37
- # https://github.com/evolutionaryscale/esm/blob/main/esm/layers/rotary.py
38
- # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
39
- # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
40
- def rotate_half(x, interleaved=False):
41
- if not interleaved:
42
- x1, x2 = x.chunk(2, dim=-1)
43
- return torch.cat((-x2, x1), dim=-1)
44
- else:
45
- x1, x2 = x[..., ::2], x[..., 1::2]
46
- return rearrange(
47
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
48
- )
49
-
50
-
51
- def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
52
- """
53
- x: (batch_size, seqlen, nheads, headdim)
54
- cos, sin: (seqlen, rotary_dim / 2)
55
- """
56
- ro_dim = cos.shape[-1] * 2
57
- assert ro_dim <= x.shape[-1]
58
- seqlen = x.size(1)
59
- cos = cos[:seqlen]
60
- sin = sin[:seqlen]
61
- cos = repeat(cos, "s d -> s 1 (2 d)")
62
- sin = repeat(sin, "s d -> s 1 (2 d)")
63
- return torch.cat(
64
- [
65
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
66
- x[..., ro_dim:],
67
- ],
68
- dim=-1,
69
- )
70
-
71
-
72
- class RotaryEmbedding(torch.nn.Module):
73
- def __init__(
74
- self,
75
- dim: int,
76
- base=10000.0,
77
- interleaved=False,
78
- scale_base=None,
79
- scaling_factor=1.0,
80
- pos_idx_in_fp32=True,
81
- device=None,
82
- ):
83
- super().__init__()
84
- self.dim = dim
85
- self.base = float(base)
86
- self.pos_idx_in_fp32 = pos_idx_in_fp32
87
- # Generate and save the inverse frequency buffer (non-trainable)
88
- self.interleaved = interleaved
89
- self.scale_base = scale_base
90
- self.scaling_factor = scaling_factor
91
- self.device = device
92
-
93
- self._seq_len_cached = 0
94
- self._cos_cached = None
95
- self._sin_cached = None
96
- self._cos_k_cached = None
97
- self._sin_k_cached = None
98
- self.reset_parameters()
99
-
100
- def reset_parameters(self):
101
- inv_freq = self._compute_inv_freq(self.device)
102
- self.register_buffer("inv_freq", inv_freq, persistent=False)
103
- arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
104
- scale = (
105
- (arange + 0.4 * self.dim) / (1.4 * self.dim)
106
- if self.scale_base is not None
107
- else None
108
- )
109
- self.register_buffer("scale", scale)
110
-
111
- def _compute_inv_freq(self, device=None):
112
- return 1 / (
113
- self.base
114
- ** (
115
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
116
- / self.dim
117
- )
118
- )
119
-
120
- def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
121
- if (
122
- seqlen > self._seq_len_cached
123
- or self._cos_cached is None
124
- or self._cos_cached.device != device
125
- or self._cos_cached.dtype != dtype
126
- or (self.training and self._cos_cached.is_inference())
127
- ):
128
- self._seq_len_cached = seqlen
129
- if self.pos_idx_in_fp32:
130
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
131
- t /= self.scaling_factor
132
- if self.inv_freq.dtype != torch.float32:
133
- inv_freq = self.inv_freq.to(torch.float32)
134
- else:
135
- inv_freq = self.inv_freq
136
- else:
137
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
138
- t /= self.scaling_factor
139
- inv_freq = self.inv_freq
140
- freqs = torch.outer(t, inv_freq)
141
-
142
- if self.scale is None:
143
- self._cos_cached = torch.cos(freqs).to(dtype)
144
- self._sin_cached = torch.sin(freqs).to(dtype)
145
- else:
146
- power = (
147
- torch.arange(
148
- seqlen, dtype=self.scale.dtype, device=self.scale.device
149
- )
150
- - seqlen // 2
151
- ) / self.scale_base
152
- scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
153
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
154
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
155
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
156
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
157
-
158
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
159
- """
160
- q: (batch, seqlen, nheads, headdim)
161
- k: (batch, seqlen, nheads, headdim)
162
- """
163
- self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
164
- assert self._cos_cached is not None
165
- assert self._sin_cached is not None
166
- if self.scale is None:
167
- return (
168
- apply_rotary_emb_torch(
169
- q,
170
- self._cos_cached,
171
- self._sin_cached,
172
- self.interleaved,
173
- True, # inplace=True
174
- ),
175
- apply_rotary_emb_torch(
176
- k,
177
- self._cos_cached,
178
- self._sin_cached,
179
- self.interleaved,
180
- True, # inplace=True
181
- ),
182
- ) # type: ignore
183
- else:
184
- assert False
185
-
186
-
187
- ### Feedforward
188
- def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
189
- return int(((expansion_ratio * d_model) + 255) // 256 * 256)
190
-
191
-
192
- class SwiGLU(nn.Module):
193
- def __init__(self):
194
- super(SwiGLU, self).__init__()
195
-
196
- def forward(self, x: torch.Tensor) -> torch.Tensor:
197
- x1, x2 = x.chunk(2, dim=-1)
198
- return F.silu(x1) * x2
199
-
200
-
201
- def swiglu_ln_ffn(d_model: int, expansion_ratio: float):
202
- return nn.Sequential(
203
- nn.LayerNorm(d_model),
204
- nn.Linear(
205
- d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
206
- ),
207
- SwiGLU(),
208
- nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
209
- )
210
-
211
-
212
- ### Attention
213
- class MultiHeadAttention(nn.Module):
214
- def __init__(self, d_model: int, n_heads: int):
215
- super().__init__()
216
- self.d_model = d_model
217
- self.n_heads = n_heads
218
- self.d_head = self.d_model // self.n_heads
219
- self.layernorm_qkv = nn.Sequential(
220
- nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
221
- )
222
- self.out_proj = nn.Linear(d_model, d_model, bias=False)
223
- self.q_ln = nn.LayerNorm(d_model, bias=False)
224
- self.k_ln = nn.LayerNorm(d_model, bias=False)
225
- self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
226
- self.rotary = RotaryEmbedding(d_model // n_heads)
227
-
228
- def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
229
- q = q.unflatten(-1, (self.n_heads, self.d_head))
230
- k = k.unflatten(-1, (self.n_heads, self.d_head))
231
- q, k = self.rotary(q, k)
232
- q = q.flatten(-2, -1)
233
- k = k.flatten(-2, -1)
234
- return q, k
235
-
236
- def forward(self, x, attention_mask=None):
237
- qkv_BLD3 = self.layernorm_qkv(x)
238
- query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
239
- query_BLD, key_BLD = (
240
- self.q_ln(query_BLD).to(query_BLD.dtype),
241
- self.k_ln(key_BLD).to(query_BLD.dtype),
242
- )
243
- query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
244
- query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
245
- context_BHLD = F.scaled_dot_product_attention(
246
- query_BHLD, key_BHLD, value_BHLD, attention_mask
247
- )
248
- context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
249
- return self.out_proj(context_BLD)
250
-
251
-
252
- ### LM Head
253
- def RegressionHead(
254
- d_model: int, output_dim: int, hidden_dim: int | None = None
255
- ) -> nn.Module:
256
- hidden_dim = hidden_dim if hidden_dim is not None else d_model
257
- return nn.Sequential(
258
- nn.Linear(d_model, hidden_dim),
259
- nn.GELU(),
260
- nn.LayerNorm(hidden_dim),
261
- nn.Linear(hidden_dim, output_dim),
262
- )
263
-
264
-
265
- ### Transformer Block
266
- class UnifiedTransformerBlock(nn.Module):
267
- def __init__(
268
- self,
269
- d_model: int,
270
- n_heads: int,
271
- residue_scaling_factor: float = 1,
272
- expansion_ratio: float = 8 / 3,
273
- ):
274
- super().__init__()
275
- self.attn = MultiHeadAttention(d_model, n_heads)
276
- self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
277
- self.scaling_factor = residue_scaling_factor
278
-
279
- def forward(
280
- self,
281
- x: torch.Tensor,
282
- attention_mask: Optional[torch.Tensor] = None,
283
- ) -> torch.Tensor:
284
- r1 = self.attn(x, attention_mask)
285
- x = x + r1 / self.scaling_factor
286
- r3 = self.ffn(x) / self.scaling_factor
287
- x = x + r3
288
- return x
289
-
290
-
291
- ### Outputs
292
- @dataclass
293
- class TransformerOutput(ModelOutput):
294
- last_hidden_state: torch.Tensor | None = None
295
- hidden_states: tuple[torch.Tensor] | None = None
296
-
297
-
298
- @dataclass
299
- class ESMplusplusOutput(ModelOutput):
300
- loss: torch.Tensor | None = None
301
- logits: torch.Tensor | None = None
302
- last_hidden_state: torch.Tensor | None = None
303
- hidden_states: tuple[torch.Tensor] | None = None
304
-
305
-
306
- ### Transformer
307
- class TransformerStack(nn.Module):
308
- def __init__(
309
- self,
310
- d_model: int,
311
- n_heads: int,
312
- n_layers: int,
313
- ):
314
- super().__init__()
315
- self.blocks = nn.ModuleList(
316
- [
317
- UnifiedTransformerBlock(
318
- d_model,
319
- n_heads,
320
- residue_scaling_factor=math.sqrt(n_layers / 36),
321
- )
322
- for i in range(n_layers)
323
- ]
324
- )
325
- self.norm = nn.LayerNorm(d_model, bias=False)
326
-
327
- def forward(
328
- self,
329
- x: torch.Tensor,
330
- attention_mask: Optional[torch.Tensor] = None,
331
- output_hidden_states: bool = False,
332
- ) -> TransformerOutput:
333
- batch_size, seq_len, _ = x.shape
334
- hidden_states = ()
335
- if attention_mask is not None:
336
- attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
337
- for block in self.blocks:
338
- x = block(x, attention_mask)
339
- if output_hidden_states:
340
- hidden_states += (x,)
341
- return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
342
-
343
-
344
- ### Full model
345
- class ESMplusplusForMaskedLM(PreTrainedModel):
346
- """
347
- ESM++ for masked language modeling.
348
- """
349
- config_class = ESMplusplusConfig
350
- def __init__(self, config: ESMplusplusConfig):
351
- super().__init__(config)
352
- self.config = config
353
- self.vocab_size = config.vocab_size
354
- self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
355
- self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
356
- self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
357
- self.ce_loss = nn.CrossEntropyLoss()
358
- self.tokenizer = EsmSequenceTokenizer()
359
-
360
- @classmethod
361
- def from_pretrained_esm(cls, model_name: str):
362
- if '300' in model_name:
363
- return ESMplusplus_300M()
364
- elif '600' in model_name:
365
- return ESMplusplus_600M()
366
- else:
367
- raise ValueError(f"Invalid model name: {model_name}")
368
-
369
- @property
370
- def device(self):
371
- return next(self.parameters()).device
372
-
373
- def forward(
374
- self,
375
- input_ids: torch.Tensor | None = None,
376
- attention_mask: Optional[torch.Tensor] = None,
377
- labels: Optional[torch.Tensor] = None,
378
- output_hidden_states: bool = False,
379
- ) -> ESMplusplusOutput:
380
- x = self.embed(input_ids)
381
- output = self.transformer(x, attention_mask, output_hidden_states)
382
- x = output.last_hidden_state
383
- logits = self.sequence_head(x)
384
- loss = None
385
- if labels is not None:
386
- loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
387
- return ESMplusplusOutput(
388
- loss=loss,
389
- logits=logits,
390
- last_hidden_state=x,
391
- hidden_states=output.hidden_states,
392
- )
393
-
394
-
395
- class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
396
- """
397
- ESM++ for sequence classification.
398
- """
399
- def __init__(self, config: ESMplusplusConfig):
400
- super().__init__(config)
401
- self.config = config
402
- self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
403
- # we find that large intermediate projections help with sequence classification tasks (*4)
404
- self.mse = nn.MSELoss()
405
- self.ce = nn.CrossEntropyLoss()
406
- self.bce = nn.BCEWithLogitsLoss()
407
-
408
- def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
409
- # x: (batch_size, seq_len, hidden_size)
410
- # attention_mask: (batch_size, seq_len)
411
- if attention_mask is None:
412
- return x.mean(dim=1)
413
- else:
414
- return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
415
-
416
- def forward(
417
- self,
418
- input_ids: torch.Tensor | None = None,
419
- attention_mask: Optional[torch.Tensor] = None,
420
- labels: Optional[torch.Tensor] = None,
421
- output_hidden_states: bool = False,
422
- ) -> ESMplusplusOutput:
423
- output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
424
- x = output.last_hidden_state
425
- cls_features = x[:, 0, :]
426
- mean_features = self.mean_pooling(x, attention_mask)
427
- # we include mean pooling features to help with early convergence, the cost of this is basically zero
428
- features = torch.cat([cls_features, mean_features], dim=-1)
429
- logits = self.classifier(features)
430
- loss = None
431
- if labels is not None:
432
- labels = labels.to(logits.device)
433
- if self.config.problem_type is None:
434
- if self.num_labels == 1:
435
- self.config.problem_type = "regression"
436
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
437
- self.config.problem_type = "single_label_classification"
438
- else:
439
- self.config.problem_type = "multi_label_classification"
440
-
441
- if self.config.problem_type == "regression":
442
- if self.num_labels == 1:
443
- loss = self.mse(logits.squeeze(), labels.squeeze())
444
- else:
445
- loss = self.mse(logits, labels)
446
- elif self.config.problem_type == "single_label_classification":
447
- loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
448
- elif self.config.problem_type == "multi_label_classification":
449
- loss = self.bce(logits, labels)
450
- return ESMplusplusOutput(
451
- loss=loss,
452
- logits=logits,
453
- last_hidden_state=x,
454
- hidden_states=output.hidden_states,
455
- )
456
-
457
-
458
- class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
459
- """
460
- ESM++ for token classification.
461
- """
462
- def __init__(self, config: ESMplusplusConfig):
463
- super().__init__(config)
464
- self.config = config
465
- self.num_labels = config.num_labels
466
- self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
467
- # we find that large intermediate projections help with sequence classification tasks (*4)
468
- self.loss_fct = nn.CrossEntropyLoss()
469
-
470
- def forward(
471
- self,
472
- input_ids: torch.Tensor | None = None,
473
- attention_mask: Optional[torch.Tensor] = None,
474
- labels: Optional[torch.Tensor] = None,
475
- output_hidden_states: bool = False,
476
- ) -> ESMplusplusOutput:
477
- output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
478
- x = output.last_hidden_state
479
- logits = self.classifier(x)
480
- loss = None
481
- if labels is not None:
482
- loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
483
- return ESMplusplusOutput(
484
- loss=loss,
485
- logits=logits,
486
- last_hidden_state=x,
487
- hidden_states=output.hidden_states,
488
- )
489
-
490
-
491
- ### Loading
492
- import os
493
- from functools import cache
494
- from pathlib import Path
495
- from huggingface_hub import snapshot_download
496
-
497
-
498
- @staticmethod
499
- @cache
500
- def data_root(model: str):
501
- if "INFRA_PROVIDER" in os.environ:
502
- return Path("")
503
- # Try to download from huggingface if it doesn't exist
504
- if model.startswith("esmc-300"):
505
- path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
506
- elif model.startswith("esmc-600"):
507
- path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
508
- else:
509
- raise ValueError(f"{model=} is an invalid model name.")
510
- return path
511
-
512
-
513
- def ESMplusplus_300M(device: torch.device | str = "cpu"):
514
- with torch.device(device):
515
- config = ESMplusplusConfig(
516
- hidden_size=960,
517
- num_attention_heads=15,
518
- num_hidden_layers=30,
519
- )
520
- model = ESMplusplusForMaskedLM(config)
521
- state_dict = torch.load(
522
- data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
523
- map_location=device,
524
- )
525
- model.load_state_dict(state_dict)
526
- return model
527
-
528
-
529
- def ESMplusplus_600M(device: torch.device | str = "cpu"):
530
- with torch.device(device):
531
- config = ESMplusplusConfig(
532
- hidden_size=1152,
533
- num_attention_heads=18,
534
- num_hidden_layers=36,
535
- )
536
- model = ESMplusplusForMaskedLM(config)
537
- state_dict = torch.load(
538
- data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
539
- map_location=device,
540
- )
541
- model.load_state_dict(state_dict)
542
- return model
543
-
544
-
545
- ### Tokenization
546
- from tokenizers import Tokenizer
547
- from tokenizers.models import BPE
548
- from tokenizers.processors import TemplateProcessing
549
- from transformers import PreTrainedTokenizerFast
550
-
551
-
552
- SEQUENCE_VOCAB = [
553
- "<cls>", "<pad>", "<eos>", "<unk>",
554
- "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
555
- "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
556
- "O", ".", "-", "|",
557
- "<mask>",
558
- ]
559
-
560
- class EsmSequenceTokenizer(PreTrainedTokenizerFast):
561
- model_input_names = ["input_ids", "attention_mask"]
562
-
563
- def __init__(
564
- self,
565
- unk_token="<unk>",
566
- cls_token="<cls>",
567
- pad_token="<pad>",
568
- mask_token="<mask>",
569
- eos_token="<eos>",
570
- chain_break_token="|",
571
- **kwargs,
572
- ):
573
- all_tokens = SEQUENCE_VOCAB
574
- token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
575
-
576
- # a character-level tokenizer is the same as BPE with no token merges
577
- bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
578
- tokenizer = Tokenizer(bpe)
579
- special_tokens = [
580
- cls_token,
581
- pad_token,
582
- mask_token,
583
- eos_token,
584
- chain_break_token,
585
- ]
586
- self.cb_token = chain_break_token
587
- additional_special_tokens = [chain_break_token]
588
-
589
- tokenizer.add_special_tokens(special_tokens)
590
-
591
- # This is where we configure the automatic addition of special tokens when we call
592
- # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
593
- # sequences are merged if you want.
594
- tokenizer.post_processor = TemplateProcessing( # type: ignore
595
- single="<cls> $A <eos>",
596
- special_tokens=[
597
- ("<cls>", tokenizer.token_to_id("<cls>")),
598
- ("<eos>", tokenizer.token_to_id("<eos>")),
599
- ],
600
- )
601
- super().__init__(
602
- tokenizer_object=tokenizer,
603
- unk_token=unk_token,
604
- cls_token=cls_token,
605
- pad_token=pad_token,
606
- mask_token=mask_token,
607
- eos_token=eos_token,
608
- additional_special_tokens=additional_special_tokens,
609
- **kwargs,
610
- )
611
-
612
- # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
613
- @property
614
- def bos_token(self):
615
- return self.cls_token
616
-
617
- @property
618
- def bos_token_id(self):
619
- return self.cls_token_id
620
-
621
- @property
622
- def chain_break_token(self):
623
- return self.cb_token
624
-
625
- @property
626
- def chain_break_token_id(self):
627
- return self.convert_tokens_to_ids(self.chain_break_token)
628
-
629
- @property
630
- def all_token_ids(self):
631
- return list(range(self.vocab_size))
632
-
633
- @property
634
- def special_token_ids(self):
635
- return self.all_special_ids
 
 
1
+ ### Modified from https://github.com/evolutionaryscale/esm
2
+ ### License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from dataclasses import dataclass
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
+ from einops import rearrange, repeat
10
+ from functools import partial
11
+ from typing import Optional, Tuple
12
+ from transformers.modeling_outputs import ModelOutput
13
+
14
+
15
+ class ESMplusplusConfig(PretrainedConfig):
16
+ model_type = "ESMplusplus"
17
+ def __init__(
18
+ self,
19
+ vocab_size: int = 64,
20
+ hidden_size: int = 960,
21
+ num_attention_heads: int = 15,
22
+ num_hidden_layers: int = 30,
23
+ num_labels: int = 2,
24
+ problem_type: str | None = None,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(**kwargs)
28
+ self.vocab_size = vocab_size
29
+ self.hidden_size = hidden_size
30
+ self.num_attention_heads = num_attention_heads
31
+ self.num_hidden_layers = num_hidden_layers
32
+ self.num_labels = num_labels
33
+ self.problem_type = problem_type
34
+
35
+
36
+ ### Rotary
37
+ # https://github.com/evolutionaryscale/esm/blob/main/esm/layers/rotary.py
38
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
39
+ # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
40
+ def rotate_half(x, interleaved=False):
41
+ if not interleaved:
42
+ x1, x2 = x.chunk(2, dim=-1)
43
+ return torch.cat((-x2, x1), dim=-1)
44
+ else:
45
+ x1, x2 = x[..., ::2], x[..., 1::2]
46
+ return rearrange(
47
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
48
+ )
49
+
50
+
51
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
52
+ """
53
+ x: (batch_size, seqlen, nheads, headdim)
54
+ cos, sin: (seqlen, rotary_dim / 2)
55
+ """
56
+ ro_dim = cos.shape[-1] * 2
57
+ assert ro_dim <= x.shape[-1]
58
+ seqlen = x.size(1)
59
+ cos = cos[:seqlen]
60
+ sin = sin[:seqlen]
61
+ cos = repeat(cos, "s d -> s 1 (2 d)")
62
+ sin = repeat(sin, "s d -> s 1 (2 d)")
63
+ return torch.cat(
64
+ [
65
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
66
+ x[..., ro_dim:],
67
+ ],
68
+ dim=-1,
69
+ )
70
+
71
+
72
+ class RotaryEmbedding(torch.nn.Module):
73
+ def __init__(
74
+ self,
75
+ dim: int,
76
+ base=10000.0,
77
+ interleaved=False,
78
+ scale_base=None,
79
+ scaling_factor=1.0,
80
+ pos_idx_in_fp32=True,
81
+ device=None,
82
+ ):
83
+ super().__init__()
84
+ self.dim = dim
85
+ self.base = float(base)
86
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
87
+ # Generate and save the inverse frequency buffer (non-trainable)
88
+ self.interleaved = interleaved
89
+ self.scale_base = scale_base
90
+ self.scaling_factor = scaling_factor
91
+ self.device = device
92
+
93
+ self._seq_len_cached = 0
94
+ self._cos_cached = None
95
+ self._sin_cached = None
96
+ self._cos_k_cached = None
97
+ self._sin_k_cached = None
98
+ self.reset_parameters()
99
+
100
+ def reset_parameters(self):
101
+ inv_freq = self._compute_inv_freq(self.device)
102
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
103
+ arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
104
+ scale = (
105
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
106
+ if self.scale_base is not None
107
+ else None
108
+ )
109
+ self.register_buffer("scale", scale)
110
+
111
+ def _compute_inv_freq(self, device=None):
112
+ return 1 / (
113
+ self.base
114
+ ** (
115
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
116
+ / self.dim
117
+ )
118
+ )
119
+
120
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
121
+ if (
122
+ seqlen > self._seq_len_cached
123
+ or self._cos_cached is None
124
+ or self._cos_cached.device != device
125
+ or self._cos_cached.dtype != dtype
126
+ or (self.training and self._cos_cached.is_inference())
127
+ ):
128
+ self._seq_len_cached = seqlen
129
+ if self.pos_idx_in_fp32:
130
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
131
+ t /= self.scaling_factor
132
+ if self.inv_freq.dtype != torch.float32:
133
+ inv_freq = self.inv_freq.to(torch.float32)
134
+ else:
135
+ inv_freq = self.inv_freq
136
+ else:
137
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
138
+ t /= self.scaling_factor
139
+ inv_freq = self.inv_freq
140
+ freqs = torch.outer(t, inv_freq)
141
+
142
+ if self.scale is None:
143
+ self._cos_cached = torch.cos(freqs).to(dtype)
144
+ self._sin_cached = torch.sin(freqs).to(dtype)
145
+ else:
146
+ power = (
147
+ torch.arange(
148
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
149
+ )
150
+ - seqlen // 2
151
+ ) / self.scale_base
152
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
153
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
154
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
155
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
156
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
157
+
158
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
159
+ """
160
+ q: (batch, seqlen, nheads, headdim)
161
+ k: (batch, seqlen, nheads, headdim)
162
+ """
163
+ self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
164
+ assert self._cos_cached is not None
165
+ assert self._sin_cached is not None
166
+ if self.scale is None:
167
+ return (
168
+ apply_rotary_emb_torch(
169
+ q,
170
+ self._cos_cached,
171
+ self._sin_cached,
172
+ self.interleaved,
173
+ True, # inplace=True
174
+ ),
175
+ apply_rotary_emb_torch(
176
+ k,
177
+ self._cos_cached,
178
+ self._sin_cached,
179
+ self.interleaved,
180
+ True, # inplace=True
181
+ ),
182
+ ) # type: ignore
183
+ else:
184
+ assert False
185
+
186
+
187
+ ### Feedforward
188
+ def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
189
+ return int(((expansion_ratio * d_model) + 255) // 256 * 256)
190
+
191
+
192
+ class SwiGLU(nn.Module):
193
+ def __init__(self):
194
+ super(SwiGLU, self).__init__()
195
+
196
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
197
+ x1, x2 = x.chunk(2, dim=-1)
198
+ return F.silu(x1) * x2
199
+
200
+
201
+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float):
202
+ return nn.Sequential(
203
+ nn.LayerNorm(d_model),
204
+ nn.Linear(
205
+ d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
206
+ ),
207
+ SwiGLU(),
208
+ nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
209
+ )
210
+
211
+
212
+ ### Attention
213
+ class MultiHeadAttention(nn.Module):
214
+ def __init__(self, d_model: int, n_heads: int):
215
+ super().__init__()
216
+ self.d_model = d_model
217
+ self.n_heads = n_heads
218
+ self.d_head = self.d_model // self.n_heads
219
+ self.layernorm_qkv = nn.Sequential(
220
+ nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
221
+ )
222
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
223
+ self.q_ln = nn.LayerNorm(d_model, bias=False)
224
+ self.k_ln = nn.LayerNorm(d_model, bias=False)
225
+ self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
226
+ self.rotary = RotaryEmbedding(d_model // n_heads)
227
+
228
+ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
229
+ q = q.unflatten(-1, (self.n_heads, self.d_head))
230
+ k = k.unflatten(-1, (self.n_heads, self.d_head))
231
+ q, k = self.rotary(q, k)
232
+ q = q.flatten(-2, -1)
233
+ k = k.flatten(-2, -1)
234
+ return q, k
235
+
236
+ def forward(self, x, attention_mask=None):
237
+ qkv_BLD3 = self.layernorm_qkv(x)
238
+ query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
239
+ query_BLD, key_BLD = (
240
+ self.q_ln(query_BLD).to(query_BLD.dtype),
241
+ self.k_ln(key_BLD).to(query_BLD.dtype),
242
+ )
243
+ query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
244
+ query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
245
+ context_BHLD = F.scaled_dot_product_attention(
246
+ query_BHLD, key_BHLD, value_BHLD, attention_mask
247
+ )
248
+ context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
249
+ return self.out_proj(context_BLD)
250
+
251
+
252
+ ### LM Head
253
+ def RegressionHead(
254
+ d_model: int, output_dim: int, hidden_dim: int | None = None
255
+ ) -> nn.Module:
256
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
257
+ return nn.Sequential(
258
+ nn.Linear(d_model, hidden_dim),
259
+ nn.GELU(),
260
+ nn.LayerNorm(hidden_dim),
261
+ nn.Linear(hidden_dim, output_dim),
262
+ )
263
+
264
+
265
+ ### Transformer Block
266
+ class UnifiedTransformerBlock(nn.Module):
267
+ def __init__(
268
+ self,
269
+ d_model: int,
270
+ n_heads: int,
271
+ residue_scaling_factor: float = 1,
272
+ expansion_ratio: float = 8 / 3,
273
+ ):
274
+ super().__init__()
275
+ self.attn = MultiHeadAttention(d_model, n_heads)
276
+ self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
277
+ self.scaling_factor = residue_scaling_factor
278
+
279
+ def forward(
280
+ self,
281
+ x: torch.Tensor,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ ) -> torch.Tensor:
284
+ r1 = self.attn(x, attention_mask)
285
+ x = x + r1 / self.scaling_factor
286
+ r3 = self.ffn(x) / self.scaling_factor
287
+ x = x + r3
288
+ return x
289
+
290
+
291
+ ### Outputs
292
+ @dataclass
293
+ class TransformerOutput(ModelOutput):
294
+ last_hidden_state: torch.Tensor | None = None
295
+ hidden_states: tuple[torch.Tensor] | None = None
296
+
297
+
298
+ @dataclass
299
+ class ESMplusplusOutput(ModelOutput):
300
+ loss: torch.Tensor | None = None
301
+ logits: torch.Tensor | None = None
302
+ last_hidden_state: torch.Tensor | None = None
303
+ hidden_states: tuple[torch.Tensor] | None = None
304
+
305
+
306
+ ### Transformer
307
+ class TransformerStack(nn.Module):
308
+ def __init__(
309
+ self,
310
+ d_model: int,
311
+ n_heads: int,
312
+ n_layers: int,
313
+ ):
314
+ super().__init__()
315
+ self.blocks = nn.ModuleList(
316
+ [
317
+ UnifiedTransformerBlock(
318
+ d_model,
319
+ n_heads,
320
+ residue_scaling_factor=math.sqrt(n_layers / 36),
321
+ )
322
+ for i in range(n_layers)
323
+ ]
324
+ )
325
+ self.norm = nn.LayerNorm(d_model, bias=False)
326
+
327
+ def forward(
328
+ self,
329
+ x: torch.Tensor,
330
+ attention_mask: Optional[torch.Tensor] = None,
331
+ output_hidden_states: bool = False,
332
+ ) -> TransformerOutput:
333
+ batch_size, seq_len, _ = x.shape
334
+ hidden_states = ()
335
+ if attention_mask is not None:
336
+ attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
337
+ for block in self.blocks:
338
+ x = block(x, attention_mask)
339
+ if output_hidden_states:
340
+ hidden_states += (x,)
341
+ return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
342
+
343
+
344
+ ### Full model
345
+ class ESMplusplusForMaskedLM(PreTrainedModel):
346
+ """
347
+ ESM++ for masked language modeling.
348
+ """
349
+ config_class = ESMplusplusConfig
350
+ def __init__(self, config: ESMplusplusConfig):
351
+ super().__init__(config)
352
+ self.config = config
353
+ self.vocab_size = config.vocab_size
354
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
355
+ self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
356
+ self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
357
+ self.ce_loss = nn.CrossEntropyLoss()
358
+ self.tokenizer = EsmSequenceTokenizer()
359
+
360
+ @classmethod
361
+ def from_pretrained_esm(cls, model_name: str):
362
+ if '300' in model_name:
363
+ return ESMplusplus_300M()
364
+ elif '600' in model_name:
365
+ return ESMplusplus_600M()
366
+ else:
367
+ raise ValueError(f"Invalid model name: {model_name}")
368
+
369
+ @property
370
+ def device(self):
371
+ return next(self.parameters()).device
372
+
373
+ def forward(
374
+ self,
375
+ input_ids: torch.Tensor | None = None,
376
+ attention_mask: Optional[torch.Tensor] = None,
377
+ labels: Optional[torch.Tensor] = None,
378
+ output_hidden_states: bool = False,
379
+ ) -> ESMplusplusOutput:
380
+ x = self.embed(input_ids)
381
+ output = self.transformer(x, attention_mask, output_hidden_states)
382
+ x = output.last_hidden_state
383
+ logits = self.sequence_head(x)
384
+ loss = None
385
+ if labels is not None:
386
+ loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
387
+ return ESMplusplusOutput(
388
+ loss=loss,
389
+ logits=logits,
390
+ last_hidden_state=x,
391
+ hidden_states=output.hidden_states,
392
+ )
393
+
394
+
395
+ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
396
+ """
397
+ ESM++ for sequence classification.
398
+ """
399
+ def __init__(self, config: ESMplusplusConfig):
400
+ super().__init__(config)
401
+ self.config = config
402
+ self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
403
+ # we find that large intermediate projections help with sequence classification tasks (*4)
404
+ self.mse = nn.MSELoss()
405
+ self.ce = nn.CrossEntropyLoss()
406
+ self.bce = nn.BCEWithLogitsLoss()
407
+
408
+ def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
409
+ # x: (batch_size, seq_len, hidden_size)
410
+ # attention_mask: (batch_size, seq_len)
411
+ if attention_mask is None:
412
+ return x.mean(dim=1)
413
+ else:
414
+ attention_mask = attention_mask.unsqueeze(-1)
415
+ return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
416
+
417
+ def forward(
418
+ self,
419
+ input_ids: torch.Tensor | None = None,
420
+ attention_mask: Optional[torch.Tensor] = None,
421
+ labels: Optional[torch.Tensor] = None,
422
+ output_hidden_states: bool = False,
423
+ ) -> ESMplusplusOutput:
424
+ output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
425
+ x = output.last_hidden_state
426
+ cls_features = x[:, 0, :]
427
+ mean_features = self.mean_pooling(x, attention_mask)
428
+ # we include mean pooling features to help with early convergence, the cost of this is basically zero
429
+ features = torch.cat([cls_features, mean_features], dim=-1)
430
+ logits = self.classifier(features)
431
+ loss = None
432
+ if labels is not None:
433
+ labels = labels.to(logits.device)
434
+ if self.config.problem_type is None:
435
+ if self.num_labels == 1:
436
+ self.config.problem_type = "regression"
437
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
438
+ self.config.problem_type = "single_label_classification"
439
+ else:
440
+ self.config.problem_type = "multi_label_classification"
441
+
442
+ if self.config.problem_type == "regression":
443
+ if self.num_labels == 1:
444
+ loss = self.mse(logits.flatten(), labels.flatten())
445
+ else:
446
+ loss = self.mse(logits, labels)
447
+ elif self.config.problem_type == "single_label_classification":
448
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
449
+ elif self.config.problem_type == "multi_label_classification":
450
+ loss = self.bce(logits, labels)
451
+ return ESMplusplusOutput(
452
+ loss=loss,
453
+ logits=logits,
454
+ last_hidden_state=x,
455
+ hidden_states=output.hidden_states,
456
+ )
457
+
458
+
459
+ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
460
+ """
461
+ ESM++ for token classification.
462
+ """
463
+ def __init__(self, config: ESMplusplusConfig):
464
+ super().__init__(config)
465
+ self.config = config
466
+ self.num_labels = config.num_labels
467
+ self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
468
+ # we find that large intermediate projections help with sequence classification tasks (*4)
469
+ self.loss_fct = nn.CrossEntropyLoss()
470
+
471
+ def forward(
472
+ self,
473
+ input_ids: torch.Tensor | None = None,
474
+ attention_mask: Optional[torch.Tensor] = None,
475
+ labels: Optional[torch.Tensor] = None,
476
+ output_hidden_states: bool = False,
477
+ ) -> ESMplusplusOutput:
478
+ output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
479
+ x = output.last_hidden_state
480
+ logits = self.classifier(x)
481
+ loss = None
482
+ if labels is not None:
483
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
484
+ return ESMplusplusOutput(
485
+ loss=loss,
486
+ logits=logits,
487
+ last_hidden_state=x,
488
+ hidden_states=output.hidden_states,
489
+ )
490
+
491
+
492
+ ### Loading
493
+ import os
494
+ from functools import cache
495
+ from pathlib import Path
496
+ from huggingface_hub import snapshot_download
497
+
498
+
499
+ @staticmethod
500
+ @cache
501
+ def data_root(model: str):
502
+ if "INFRA_PROVIDER" in os.environ:
503
+ return Path("")
504
+ # Try to download from huggingface if it doesn't exist
505
+ if model.startswith("esmc-300"):
506
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
507
+ elif model.startswith("esmc-600"):
508
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
509
+ else:
510
+ raise ValueError(f"{model=} is an invalid model name.")
511
+ return path
512
+
513
+
514
+ def ESMplusplus_300M(device: torch.device | str = "cpu"):
515
+ with torch.device(device):
516
+ config = ESMplusplusConfig(
517
+ hidden_size=960,
518
+ num_attention_heads=15,
519
+ num_hidden_layers=30,
520
+ )
521
+ model = ESMplusplusForMaskedLM(config)
522
+ state_dict = torch.load(
523
+ data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
524
+ map_location=device,
525
+ )
526
+ model.load_state_dict(state_dict)
527
+ return model
528
+
529
+
530
+ def ESMplusplus_600M(device: torch.device | str = "cpu"):
531
+ with torch.device(device):
532
+ config = ESMplusplusConfig(
533
+ hidden_size=1152,
534
+ num_attention_heads=18,
535
+ num_hidden_layers=36,
536
+ )
537
+ model = ESMplusplusForMaskedLM(config)
538
+ state_dict = torch.load(
539
+ data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
540
+ map_location=device,
541
+ )
542
+ model.load_state_dict(state_dict)
543
+ return model
544
+
545
+
546
+ ### Tokenization
547
+ from tokenizers import Tokenizer
548
+ from tokenizers.models import BPE
549
+ from tokenizers.processors import TemplateProcessing
550
+ from transformers import PreTrainedTokenizerFast
551
+
552
+
553
+ SEQUENCE_VOCAB = [
554
+ "<cls>", "<pad>", "<eos>", "<unk>",
555
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
556
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
557
+ "O", ".", "-", "|",
558
+ "<mask>",
559
+ ]
560
+
561
+ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
562
+ model_input_names = ["input_ids", "attention_mask"]
563
+
564
+ def __init__(
565
+ self,
566
+ unk_token="<unk>",
567
+ cls_token="<cls>",
568
+ pad_token="<pad>",
569
+ mask_token="<mask>",
570
+ eos_token="<eos>",
571
+ chain_break_token="|",
572
+ **kwargs,
573
+ ):
574
+ all_tokens = SEQUENCE_VOCAB
575
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
576
+
577
+ # a character-level tokenizer is the same as BPE with no token merges
578
+ bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
579
+ tokenizer = Tokenizer(bpe)
580
+ special_tokens = [
581
+ cls_token,
582
+ pad_token,
583
+ mask_token,
584
+ eos_token,
585
+ chain_break_token,
586
+ ]
587
+ self.cb_token = chain_break_token
588
+ additional_special_tokens = [chain_break_token]
589
+
590
+ tokenizer.add_special_tokens(special_tokens)
591
+
592
+ # This is where we configure the automatic addition of special tokens when we call
593
+ # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
594
+ # sequences are merged if you want.
595
+ tokenizer.post_processor = TemplateProcessing( # type: ignore
596
+ single="<cls> $A <eos>",
597
+ special_tokens=[
598
+ ("<cls>", tokenizer.token_to_id("<cls>")),
599
+ ("<eos>", tokenizer.token_to_id("<eos>")),
600
+ ],
601
+ )
602
+ super().__init__(
603
+ tokenizer_object=tokenizer,
604
+ unk_token=unk_token,
605
+ cls_token=cls_token,
606
+ pad_token=pad_token,
607
+ mask_token=mask_token,
608
+ eos_token=eos_token,
609
+ additional_special_tokens=additional_special_tokens,
610
+ **kwargs,
611
+ )
612
+
613
+ # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
614
+ @property
615
+ def bos_token(self):
616
+ return self.cls_token
617
+
618
+ @property
619
+ def bos_token_id(self):
620
+ return self.cls_token_id
621
+
622
+ @property
623
+ def chain_break_token(self):
624
+ return self.cb_token
625
+
626
+ @property
627
+ def chain_break_token_id(self):
628
+ return self.convert_tokens_to_ids(self.chain_break_token)
629
+
630
+ @property
631
+ def all_token_ids(self):
632
+ return list(range(self.vocab_size))
633
+
634
+ @property
635
+ def special_token_ids(self):
636
+ return self.all_special_ids
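
For reference, a minimal usage sketch of the module above (not part of the commit): it builds the 300M model through ESMplusplus_300M(), which fetches the EvolutionaryScale checkpoint via data_root()/snapshot_download, and runs a forward pass with the bundled EsmSequenceTokenizer. The example sequences and the "cpu" device are placeholders, and network access is assumed for the weight download.

import torch

# Instantiate the 300M ESM++ masked-LM model defined in this file.
model = ESMplusplus_300M(device="cpu")
tokenizer = model.tokenizer  # EsmSequenceTokenizer attached in __init__

# Tokenize a padded batch; attention_mask distinguishes real residues from padding.
batch = tokenizer(["MKTAYIAKQR", "MKT"], padding=True, return_tensors="pt")

with torch.no_grad():
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        output_hidden_states=True,
    )

print(out.logits.shape)             # (batch, seq_len, vocab_size=64)
print(out.last_hidden_state.shape)  # (batch, seq_len, hidden_size=960)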