lhallee committed
Commit 7e7ee15 · verified · 1 Parent(s): f456813

Update modeling_esm_plusplus.py

Files changed (1)
  1. modeling_esm_plusplus.py +326 -78
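Most of the +326 lines add docstrings throughout the file plus a new embedding utility on ESMplusplusForMaskedLM: a small ProteinDataset, a mean_pooling helper, and an embed_dataset method with optional SQLite storage. A minimal usage sketch (not part of the commit; the import path and example sequences are illustrative assumptions):

# Hypothetical usage sketch; assumes the file is importable as modeling_esm_plusplus.
from modeling_esm_plusplus import ESMplusplus_300M  # constructor defined later in this file

model = ESMplusplus_300M()
sequences = ["MPRTEINSEQWENCE", "MANOTHERSEQWENCE"]  # toy examples, not from the commit

# Returns {sequence: embedding tensor}; returns None when sql=True,
# in which case embeddings are written to a SQLite database instead.
embeddings = model.embed_dataset(
    sequences,
    batch_size=2,
    max_len=512,
    full_embeddings=False,   # False -> pooled per-sequence embeddings
    pooling_type="mean",     # or "cls"
    sql=False,
)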
modeling_esm_plusplus.py CHANGED
@@ -1,18 +1,44 @@
- ### Modified from https://github.com/evolutionaryscale/esm
- ### License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import math
  from dataclasses import dataclass
- from transformers import PreTrainedModel, PretrainedConfig
- from einops import rearrange, repeat
- from functools import partial
  from typing import Optional, Tuple
  from transformers.modeling_outputs import ModelOutput


  class ESMplusplusConfig(PretrainedConfig):
  model_type = "ESMplusplus"
  def __init__(
  self,
@@ -33,11 +59,9 @@ class ESMplusplusConfig(PretrainedConfig):
  self.problem_type = problem_type


- ### Rotary
- # https://github.com/evolutionaryscale/esm/blob/main/esm/layers/rotary.py
- # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
- # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
- def rotate_half(x, interleaved=False):
  if not interleaved:
  x1, x2 = x.chunk(2, dim=-1)
  return torch.cat((-x2, x1), dim=-1)
@@ -48,11 +72,14 @@ def rotate_half(x, interleaved=False):
  )


- def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
- """
- x: (batch_size, seqlen, nheads, headdim)
- cos, sin: (seqlen, rotary_dim / 2)
- """
  ro_dim = cos.shape[-1] * 2
  assert ro_dim <= x.shape[-1]
  seqlen = x.size(1)
@@ -70,21 +97,33 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):


  class RotaryEmbedding(torch.nn.Module):
  def __init__(
  self,
  dim: int,
- base=10000.0,
- interleaved=False,
- scale_base=None,
- scaling_factor=1.0,
- pos_idx_in_fp32=True,
- device=None,
  ):
  super().__init__()
  self.dim = dim
  self.base = float(base)
  self.pos_idx_in_fp32 = pos_idx_in_fp32
- # Generate and save the inverse frequency buffer (non trainable)
  self.interleaved = interleaved
  self.scale_base = scale_base
  self.scaling_factor = scaling_factor
@@ -98,6 +137,7 @@ class RotaryEmbedding(torch.nn.Module):
  self.reset_parameters()

  def reset_parameters(self):
  inv_freq = self._compute_inv_freq(self.device)
  self.register_buffer("inv_freq", inv_freq, persistent=False)
  arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
@@ -108,7 +148,8 @@ class RotaryEmbedding(torch.nn.Module):
  )
  self.register_buffer("scale", scale)

- def _compute_inv_freq(self, device=None):
  return 1 / (
  self.base
  ** (
@@ -117,7 +158,8 @@ class RotaryEmbedding(torch.nn.Module):
  )
  )

- def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
  if (
  seqlen > self._seq_len_cached
  or self._cos_cached is None
@@ -156,9 +198,14 @@ class RotaryEmbedding(torch.nn.Module):
  self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

  def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- q: (batch, seqlen, nheads, headdim)
- k: (batch, seqlen, nheads, headdim)
  """
  self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
  assert self._cos_cached is not None
@@ -184,12 +231,14 @@ class RotaryEmbedding(torch.nn.Module):
  assert False


- ### Feedforward
  def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
  return int(((expansion_ratio * d_model) + 255) // 256 * 256)


  class SwiGLU(nn.Module):
  def __init__(self):
  super(SwiGLU, self).__init__()

@@ -198,7 +247,8 @@ class SwiGLU(nn.Module):
  return F.silu(x1) * x2


- def swiglu_ln_ffn(d_model: int, expansion_ratio: float):
  return nn.Sequential(
  nn.LayerNorm(d_model),
  nn.Linear(
@@ -211,6 +261,12 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float):

  ### Attention
  class MultiHeadAttention(nn.Module):
  def __init__(self, d_model: int, n_heads: int):
  super().__init__()
  self.d_model = d_model
@@ -225,7 +281,8 @@ class MultiHeadAttention(nn.Module):
  self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
  self.rotary = RotaryEmbedding(d_model // n_heads)

- def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
  q = q.unflatten(-1, (self.n_heads, self.d_head))
  k = k.unflatten(-1, (self.n_heads, self.d_head))
  q, k = self.rotary(q, k)
@@ -233,7 +290,15 @@ class MultiHeadAttention(nn.Module):
  k = k.flatten(-2, -1)
  return q, k

- def forward(self, x, attention_mask=None):
  qkv_BLD3 = self.layernorm_qkv(x)
  query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
  query_BLD, key_BLD = (
@@ -249,10 +314,17 @@ class MultiHeadAttention(nn.Module):
  return self.out_proj(context_BLD)


- ### LM Head
  def RegressionHead(
- d_model: int, output_dim: int, hidden_dim: int | None = None
  ) -> nn.Module:
  hidden_dim = hidden_dim if hidden_dim is not None else d_model
  return nn.Sequential(
  nn.Linear(d_model, hidden_dim),
@@ -264,6 +336,14 @@ def RegressionHead(

  ### Transformer Block
  class UnifiedTransformerBlock(nn.Module):
  def __init__(
  self,
  d_model: int,
@@ -281,6 +361,14 @@ class UnifiedTransformerBlock(nn.Module):
  x: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  r1 = self.attn(x, attention_mask)
  x = x + r1 / self.scaling_factor
  r3 = self.ffn(x) / self.scaling_factor
@@ -288,23 +376,32 @@ class UnifiedTransformerBlock(nn.Module):
  return x


- ### Outputs
  @dataclass
  class TransformerOutput(ModelOutput):
- last_hidden_state: torch.Tensor | None = None
- hidden_states: tuple[torch.Tensor] | None = None


  @dataclass
  class ESMplusplusOutput(ModelOutput):
- loss: torch.Tensor | None = None
- logits: torch.Tensor | None = None
- last_hidden_state: torch.Tensor | None = None
- hidden_states: tuple[torch.Tensor] | None = None


- ### Transformer
  class TransformerStack(nn.Module):
  def __init__(
  self,
  d_model: int,
@@ -330,6 +427,15 @@ class TransformerStack(nn.Module):
  attention_mask: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> TransformerOutput:
  batch_size, seq_len, _ = x.shape
  hidden_states = ()
  if attention_mask is not None:
@@ -341,10 +447,24 @@ class TransformerStack(nn.Module):
  return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)


- ### Full model
  class ESMplusplusForMaskedLM(PreTrainedModel):
- """
- ESM++ for masked language modeling.
  """
  config_class = ESMplusplusConfig
  def __init__(self, config: ESMplusplusConfig):
@@ -358,7 +478,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
  self.tokenizer = EsmSequenceTokenizer()

  @classmethod
- def from_pretrained_esm(cls, model_name: str):
  if '300' in model_name:
  return ESMplusplus_300M()
  elif '600' in model_name:
@@ -367,16 +488,140 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
  raise ValueError(f"Invalid model name: {model_name}")

  @property
- def device(self):
  return next(self.parameters()).device

  def forward(
  self,
- input_ids: torch.Tensor | None = None,
  attention_mask: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> ESMplusplusOutput:
  x = self.embed(input_ids)
  output = self.transformer(x, attention_mask, output_hidden_states)
  x = output.last_hidden_state
@@ -393,34 +638,37 @@ class ESMplusplusForMaskedLM(PreTrainedModel):


  class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
- """
- ESM++ for sequence classification.
  """
  def __init__(self, config: ESMplusplusConfig):
  super().__init__(config)
  self.config = config
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
- # we find that large intermediate projections help with sequence classification tasks (*4)
  self.mse = nn.MSELoss()
  self.ce = nn.CrossEntropyLoss()
  self.bce = nn.BCEWithLogitsLoss()

- def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- # x: (batch_size, seq_len, hidden_size)
- # attention_mask: (batch_size, seq_len)
- if attention_mask is None:
- return x.mean(dim=1)
- else:
- attention_mask = attention_mask.unsqueeze(-1)
- return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
-
  def forward(
  self,
- input_ids: torch.Tensor | None = None,
  attention_mask: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> ESMplusplusOutput:
  output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
  x = output.last_hidden_state
  cls_features = x[:, 0, :]
@@ -457,24 +705,36 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):


  class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
- """
- ESM++ for token classification.
  """
  def __init__(self, config: ESMplusplusConfig):
  super().__init__(config)
  self.config = config
  self.num_labels = config.num_labels
  self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
- # we find that large intermediate projections help with sequence classification tasks (*4)
  self.loss_fct = nn.CrossEntropyLoss()

  def forward(
  self,
- input_ids: torch.Tensor | None = None,
  attention_mask: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> ESMplusplusOutput:
  output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
  x = output.last_hidden_state
  logits = self.classifier(x)
@@ -489,13 +749,7 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
  )


- ### Loading
- import os
- from functools import cache
- from pathlib import Path
- from huggingface_hub import snapshot_download
-
-
  @staticmethod
  @cache
  def data_root(model: str):
@@ -544,12 +798,6 @@ def ESMplusplus_600M(device: torch.device | str = "cpu"):


  ### Tokenization
- from tokenizers import Tokenizer
- from tokenizers.models import BPE
- from tokenizers.processors import TemplateProcessing
- from transformers import PreTrainedTokenizerFast
-
-
  SEQUENCE_VOCAB = [
  "<cls>", "<pad>", "<eos>", "<unk>",
  "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
 
+ """
+ ESM++ model implementation.
+
+ ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
+ The ESM Python package is not required
+
+ Modified from https://github.com/evolutionaryscale/esm
+ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
+ """
+
+ import math
+ import os
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from dataclasses import dataclass
+ from functools import cache, partial
+ from pathlib import Path
  from typing import Optional, Tuple
+ from einops import rearrange, repeat
+ from huggingface_hub import snapshot_download
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.processors import TemplateProcessing
+ from torch.utils.data import Dataset, DataLoader
+ from tqdm.auto import tqdm
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
  from transformers.modeling_outputs import ModelOutput


  class ESMplusplusConfig(PretrainedConfig):
+ """Configuration class for ESM++ model.
+
+ Args:
+ vocab_size: Size of the vocabulary
+ hidden_size: Dimension of hidden layers
+ num_attention_heads: Number of attention heads
+ num_hidden_layers: Number of transformer layers
+ num_labels: Number of output labels for classification
+ problem_type: Type of problem - regression, single/multi label classification
+ """
  model_type = "ESMplusplus"
  def __init__(
  self,

  self.problem_type = problem_type


+ ### Rotary Embeddings
+ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+ """Rotates half the hidden dims of the input."""
  if not interleaved:
  x1, x2 = x.chunk(2, dim=-1)
  return torch.cat((-x2, x1), dim=-1)

  )


+ def apply_rotary_emb_torch(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ interleaved: bool = False,
+ _inplace: bool = False,
+ ) -> torch.Tensor:
+ """Apply rotary embeddings to input based on cos and sin."""
  ro_dim = cos.shape[-1] * 2
  assert ro_dim <= x.shape[-1]
  seqlen = x.size(1)


  class RotaryEmbedding(torch.nn.Module):
+ """Rotary position embeddings.
+
+ Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"
+
+ Args:
+ dim: Dimension of the embedding
+ base: Base for computing angular frequencies
+ interleaved: Whether to use interleaved rotations
+ scale_base: Base for scaling
+ scaling_factor: Factor for scaling positions
+ pos_idx_in_fp32: Whether to compute position indices in fp32
+ device: Computation device
+ """
  def __init__(
  self,
  dim: int,
+ base: float = 10000.0,
+ interleaved: bool = False,
+ scale_base: Optional[float] = None,
+ scaling_factor: float = 1.0,
+ pos_idx_in_fp32: bool = True,
+ device: Optional[torch.device] = None,
  ):
  super().__init__()
  self.dim = dim
  self.base = float(base)
  self.pos_idx_in_fp32 = pos_idx_in_fp32
  self.interleaved = interleaved
  self.scale_base = scale_base
  self.scaling_factor = scaling_factor

  self.reset_parameters()

  def reset_parameters(self):
+ """Reset the parameters of the embedding."""
  inv_freq = self._compute_inv_freq(self.device)
  self.register_buffer("inv_freq", inv_freq, persistent=False)
  arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)

  )
  self.register_buffer("scale", scale)

+ def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
+ """Compute inverse frequency bands."""
  return 1 / (
  self.base
  ** (

  )
  )

+ def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+ """Update the cached cosine and sine values."""
  if (
  seqlen > self._seq_len_cached
  or self._cos_cached is None

  self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

  def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply rotary embeddings to queries and keys.
+
+ Args:
+ q: Query tensor of shape (batch, seqlen, nheads, headdim)
+ k: Key tensor of shape (batch, seqlen, nheads, headdim)
+
+ Returns:
+ Tuple of rotated query and key tensors
  """
  self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
  assert self._cos_cached is not None

  assert False


+ ### Feedforward Network Components
  def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
+ """Compute corrected dimension for SwiGLU."""
  return int(((expansion_ratio * d_model) + 255) // 256 * 256)


  class SwiGLU(nn.Module):
+ """SwiGLU activation function."""
  def __init__(self):
  super(SwiGLU, self).__init__()

  return F.silu(x1) * x2


+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
+ """Create SwiGLU feedforward network with layer normalization."""
  return nn.Sequential(
  nn.LayerNorm(d_model),
  nn.Linear(

  ### Attention
  class MultiHeadAttention(nn.Module):
+ """Multi-head attention with rotary embeddings.
+
+ Args:
+ d_model: Model dimension
+ n_heads: Number of attention heads
+ """
  def __init__(self, d_model: int, n_heads: int):
  super().__init__()
  self.d_model = d_model

  self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
  self.rotary = RotaryEmbedding(d_model // n_heads)

+ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply rotary embeddings to query and key."""
  q = q.unflatten(-1, (self.n_heads, self.d_head))
  k = k.unflatten(-1, (self.n_heads, self.d_head))
  q, k = self.rotary(q, k)

  k = k.flatten(-2, -1)
  return q, k

+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Args:
+ x: Input tensor
+ attention_mask: Optional attention mask
+
+ Returns:
+ Output tensor after self attention
+ """
  qkv_BLD3 = self.layernorm_qkv(x)
  query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
  query_BLD, key_BLD = (

  return self.out_proj(context_BLD)


+ ### Regression Head
  def RegressionHead(
+ d_model: int, output_dim: int, hidden_dim: Optional[int] = None
  ) -> nn.Module:
+ """Create a regression head with optional hidden dimension.
+
+ Args:
+ d_model: Input dimension
+ output_dim: Output dimension
+ hidden_dim: Optional hidden dimension (defaults to d_model)
+ """
  hidden_dim = hidden_dim if hidden_dim is not None else d_model
  return nn.Sequential(
  nn.Linear(d_model, hidden_dim),

  ### Transformer Block
  class UnifiedTransformerBlock(nn.Module):
+ """Transformer block with attention and feedforward layers.
+
+ Args:
+ d_model: Model dimension
+ n_heads: Number of attention heads
+ residue_scaling_factor: Factor for scaling residual connections
+ expansion_ratio: Expansion ratio for feedforward network
+ """
  def __init__(
  self,
  d_model: int,

  x: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
+ """
+ Args:
+ x: Input tensor
+ attention_mask: Optional attention mask
+
+ Returns:
+ Output tensor after transformer block
+ """
  r1 = self.attn(x, attention_mask)
  x = x + r1 / self.scaling_factor
  r3 = self.ffn(x) / self.scaling_factor

  return x


+ ### Model Outputs
  @dataclass
  class TransformerOutput(ModelOutput):
+ """Output type for transformer encoder."""
+ last_hidden_state: Optional[torch.Tensor] = None
+ hidden_states: Optional[Tuple[torch.Tensor]] = None


  @dataclass
  class ESMplusplusOutput(ModelOutput):
+ """Output type for ESM++ models."""
+ loss: Optional[torch.Tensor] = None
+ logits: Optional[torch.Tensor] = None
+ last_hidden_state: Optional[torch.Tensor] = None
+ hidden_states: Optional[Tuple[torch.Tensor]] = None


+ ### Transformer Stack
  class TransformerStack(nn.Module):
+ """Stack of transformer blocks.
+
+ Args:
+ d_model: Model dimension
+ n_heads: Number of attention heads
+ n_layers: Number of transformer layers
+ """
  def __init__(
  self,
  d_model: int,

  attention_mask: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> TransformerOutput:
+ """
+ Args:
+ x: Input tensor
+ attention_mask: Optional attention mask
+ output_hidden_states: Whether to return all hidden states
+
+ Returns:
+ TransformerOutput containing last hidden state and optionally all hidden states
+ """
  batch_size, seq_len, _ = x.shape
  hidden_states = ()
  if attention_mask is not None:

  return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)


+ ### Dataset for Embedding
+ class ProteinDataset(Dataset):
+ """Simple dataset for protein sequences."""
+ def __init__(self, sequences: list[str]):
+ self.sequences = sequences
+
+ def __len__(self) -> int:
+ return len(self.sequences)
+
+ def __getitem__(self, idx: int) -> str:
+ return self.sequences[idx]
+
+
+ ### ESM++ Models
  class ESMplusplusForMaskedLM(PreTrainedModel):
+ """ESM++ model for masked language modeling.
+
+ Implements the base ESM++ architecture with a masked language modeling head.
  """
  config_class = ESMplusplusConfig
  def __init__(self, config: ESMplusplusConfig):

  self.tokenizer = EsmSequenceTokenizer()

  @classmethod
+ def from_pretrained_esm(cls, model_name: str) -> "ESMplusplusForMaskedLM":
+ """Load a pretrained ESM++ model."""
  if '300' in model_name:
  return ESMplusplus_300M()
  elif '600' in model_name:

  raise ValueError(f"Invalid model name: {model_name}")

  @property
+ def device(self) -> torch.device:
+ """Get the device of the model."""
  return next(self.parameters()).device

+ def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """Apply mean pooling to sequence outputs."""
+ if attention_mask is None:
+ return x.mean(dim=1)
+ else:
+ attention_mask = attention_mask.unsqueeze(-1)
+ return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
+
+ def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Collate function for batching sequences."""
+ return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
+
+ def _read_sequences_from_db(self, db_path: str) -> set[str]:
+ """Read sequences from SQLite database."""
+ import sqlite3
+ sequences = []
+ with sqlite3.connect(db_path) as conn:
+ c = conn.cursor()
+ c.execute("SELECT sequence FROM embeddings")
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ sequences.append(row[0])
+ return set(sequences)
+
+ def embed_dataset(
+ self,
+ sequences: list[str],
+ batch_size: int = 2,
+ max_len: int = 512,
+ full_embeddings: bool = False,
+ full_precision: bool = False,
+ pooling_type: str = 'mean',
+ num_workers: int = 0,
+ sql: bool = False,
+ sql_db_path: str = 'embeddings.db',
+ ) -> Optional[dict[str, torch.Tensor]]:
+ """Embed a dataset of protein sequences.
+
+ Args:
+ sequences: List of protein sequences
+ batch_size: Batch size for processing
+ max_len: Maximum sequence length
+ full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
+ full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
+ pooling_type: Type of pooling ('mean' or 'cls')
+ num_workers: Number of workers for data loading, 0 for the main process
+ sql: Whether to store embeddings in SQLite database - will be stored in float32
+ sql_db_path: Path to SQLite database
+
+ Returns:
+ Dictionary mapping sequences to embeddings, or None if sql=True
+ """
+ sequences = list(set([seq[:max_len] for seq in sequences]))
+ sequences = sorted(sequences, key=len, reverse=True)
+ dataset = ProteinDataset(sequences)
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
+ device = self.device
+
+ def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if full_embeddings:
+ return residue_embeddings
+ elif pooling_type == 'mean':
+ return self.mean_pooling(residue_embeddings, attention_mask)
+ else:
+ return residue_embeddings[:, 0, :]
+
+ if sql:
+ import sqlite3
+ conn = sqlite3.connect(sql_db_path)
+ c = conn.cursor()
+ c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
+ already_embedded = self._read_sequences_from_db(sql_db_path)
+ to_embed = [seq for seq in sequences if seq not in already_embedded]
+ print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
+ print(f"Embedding {len(to_embed)} new sequences")
+
+ with torch.no_grad():
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+ x = self.embed(input_ids)
+ residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
+ embeddings = get_embeddings(residue_embeddings, attention_mask)
+
+ for seq, emb in zip(seqs, embeddings):
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
+ (seq, emb.cpu().numpy().tobytes()))
+
+ if (i + 1) % 100 == 0:
+ conn.commit()
+
+ conn.commit()
+ conn.close()
+ return None
+
+ embeddings_dict = {}
+ with torch.no_grad():
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+ x = self.embed(input_ids)
+ residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+ if full_precision:
+ residue_embeddings = residue_embeddings.float()
+ embeddings = get_embeddings(residue_embeddings, attention_mask)
+ for seq, emb in zip(seqs, embeddings):
+ embeddings_dict[seq] = emb
+
+ return embeddings_dict
+
  def forward(
  self,
+ input_ids: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> ESMplusplusOutput:
+ """Forward pass for masked language modeling.
+
+ Args:
+ input_ids: Input token IDs
+ attention_mask: Attention mask
+ labels: Optional labels for masked tokens
+ output_hidden_states: Whether to return all hidden states
+
+ Returns:
+ ESMplusplusOutput containing loss, logits, and hidden states
+ """
  x = self.embed(input_ids)
  output = self.transformer(x, attention_mask, output_hidden_states)
  x = output.last_hidden_state


  class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
+ """ESM++ model for sequence classification.
+
+ Extends the base ESM++ model with a classification head.
  """
  def __init__(self, config: ESMplusplusConfig):
  super().__init__(config)
  self.config = config
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
+ # Large intermediate projections help with sequence classification tasks (*4)
  self.mse = nn.MSELoss()
  self.ce = nn.CrossEntropyLoss()
  self.bce = nn.BCEWithLogitsLoss()

  def forward(
  self,
+ input_ids: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> ESMplusplusOutput:
+ """Forward pass for sequence classification.
+
+ Args:
+ input_ids: Input token IDs
+ attention_mask: Attention mask
+ labels: Optional labels for classification
+ output_hidden_states: Whether to return all hidden states
+
+ Returns:
+ ESMplusplusOutput containing loss, logits, and hidden states
+ """
  output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
  x = output.last_hidden_state
  cls_features = x[:, 0, :]


  class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
+ """ESM++ model for token classification.
+
+ Extends the base ESM++ model with a token classification head.
  """
  def __init__(self, config: ESMplusplusConfig):
  super().__init__(config)
  self.config = config
  self.num_labels = config.num_labels
  self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
+ # Large intermediate projections help with sequence classification tasks (*4)
  self.loss_fct = nn.CrossEntropyLoss()

  def forward(
  self,
+ input_ids: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  output_hidden_states: bool = False,
  ) -> ESMplusplusOutput:
+ """Forward pass for token classification.
+
+ Args:
+ input_ids: Input token IDs
+ attention_mask: Attention mask
+ labels: Optional labels for token classification
+ output_hidden_states: Whether to return all hidden states
+
+ Returns:
+ ESMplusplusOutput containing loss, logits, and hidden states
+ """
  output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
  x = output.last_hidden_state
  logits = self.classifier(x)

  )


+ ### Loading from EvolutionaryScale
  @staticmethod
  @cache
  def data_root(model: str):


  ### Tokenization
  SEQUENCE_VOCAB = [
  "<cls>", "<pad>", "<eos>", "<unk>",
  "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",