---
library_name: transformers
tags: []
---

# ESM++
[ESM++](https://github.com/Synthyra/ESMplusplus) is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement)) that supports batching and standard Hugging Face compatibility without requiring the ESM Python package.
The large version corresponds to the 600 million parameter version of ESMC.


## Use with 🤗 transformers
```python
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

# tokenized['labels'] = tokenized['input_ids'].clone() # for MLM training, mask a subset of input_ids and set the unmasked positions of labels to -100 (see the masking sketch below)

output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 1152)
print(output.loss) # language modeling loss if you passed labels
# print(output.hidden_states) # tuple of all hidden states if you passed output_hidden_states=True
```
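
For MLM training, you need to mask some of the input tokens and set the unmasked label positions to -100. Below is a minimal masking sketch; the 15% masking rate, the `mask_for_mlm` helper name, and the use of `tokenizer.mask_token_id` / `tokenizer.all_special_ids` are assumptions for illustration, not part of the official API.

```python
import torch

def mask_for_mlm(tokenized, tokenizer, mlm_probability=0.15):
    """Randomly mask input_ids and build MLM labels (assumed 15% masking rate)."""
    input_ids = tokenized['input_ids'].clone()
    labels = input_ids.clone()

    # Never mask special tokens (cls, eos, pad, etc.).
    special = torch.tensor(
        [[tok in tokenizer.all_special_ids for tok in seq] for seq in input_ids.tolist()]
    )
    probs = torch.full(labels.shape, mlm_probability)
    probs.masked_fill_(special, 0.0)
    masked = torch.bernoulli(probs).bool()

    labels[~masked] = -100                       # only masked positions contribute to the loss
    input_ids[masked] = tokenizer.mask_token_id  # replace masked positions with the mask token
    tokenized['input_ids'] = input_ids
    tokenized['labels'] = labels
    return tokenized

batch = mask_for_mlm(tokenizer(sequences, padding=True, return_tensors='pt'), tokenizer)
loss = model(**batch).loss
```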

ESM++ also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization.

```python
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
```
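
The same pattern works for per-residue labels with `AutoModelForTokenClassification`; a brief sketch (the label count of 3 is arbitrary):

```python
from transformers import AutoModelForTokenClassification

# Token-level (per-residue) classification head on top of ESM++.
token_model = AutoModelForTokenClassification.from_pretrained(
    'Synthyra/ESMplusplus_large', num_labels=3, trust_remote_code=True
)
token_logits = token_model(**tokenized).logits
print(token_logits.shape)  # (batch_size, seq_len, num_labels), e.g. (2, 11, 3)
```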

ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this:
```python
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, torch_dtype=torch.float16) # or torch.bfloat16
```

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted by length to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual runtime.
```python
embeddings = model.embed_dataset(
    sequences=sequences, # list of protein strings
    batch_size=16, # embedding batch size
    max_len=2048, # truncate to max_len
    full_embeddings=True, # return residue-wise embeddings
    full_precision=False, # set True to store embeddings at full (float32) precision
    pooling_type='mean', # use mean pooling if protein-wise embeddings
    num_workers=0, # data loading num workers
    sql=False, # return dictionary of sequences and embeddings
)

_ = model.embed_dataset(
    sequences=sequences, # list of protein strings
    batch_size=16, # embedding batch size
    max_len=2048, # truncate to max_len
    full_embeddings=True, # return residue-wise embeddings
    full_precision=False, # set True to store embeddings at full (float32) precision
    pooling_type='mean', # use mean pooling if protein-wise embeddings
    num_workers=0, # data loading num workers
    sql=True, # store sequences in local SQL database
    sql_db_path='embeddings.db', # path to .db file of choice
)
```
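
As a rough sketch of consuming the results, assuming the non-SQL call returns a dictionary keyed by sequence (as the comments above suggest):

```python
# Inspect the in-memory embeddings returned by the first call above.
for seq, emb in embeddings.items():
    print(seq, emb.shape)  # residue-wise: (seq_len, hidden_size); pooled: (hidden_size,)
```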

## Returning attention maps
Usually `F.scaled_dot_product_attention` is used for the attention calculations, which is much faster than a manual PyTorch implementation. However, it cannot return attention maps.
ESM++ has an `output_attentions` option, which calculates attention manually. This is much slower, so only use it if you need the attention maps.

```python
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att) # one attention map per layer (36 for the large model), each of shape (batch_size, num_heads, seq_len, seq_len)
```
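
For example, you can stack the per-layer maps and average over heads to get one residue-residue map per layer (a small usage sketch, not part of the model card API):

```python
import torch

# Stack per-layer attentions into (num_layers, batch_size, num_heads, seq_len, seq_len).
stacked = torch.stack(att)
# Average over heads for a single residue-residue map per layer.
per_layer = stacked.mean(dim=2)
print(per_layer.shape)  # (num_layers, batch_size, seq_len, seq_len)
```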

## Comparison across floating-point precision and implementations
We measured the mean squared error (MSE) between the last hidden states produced by the fp32 weights and those produced by the fp16 or bf16 weights. The fp16 outputs are closer to fp32, so we recommend loading in fp16.
Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default; each choice has its own advantages and disadvantages for inference and training, so load whichever half-precision format you prefer.

Average MSE for FP16: 0.00000003

Average MSE for BF16: 0.00000122

We also measured the difference between the outputs of ESM++ vs. ESMC (both in bfloat16) on 1000 random sequences to ensure compliance with the ESM package.

Average MSE of last hidden state: 2.46e-09

You can load the weights from the ESM package instead of transformers by replacing `.from_pretrained(...)` with `.from_pretrained_esm('esmc_600m')`.
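
A rough sketch of this kind of precision comparison (illustrative only, not the exact evaluation script):

```python
import torch
from transformers import AutoModelForMaskedLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the same checkpoint twice, once in fp32 and once in fp16 (fp16 is best run on GPU).
model_fp32 = AutoModelForMaskedLM.from_pretrained(
    'Synthyra/ESMplusplus_large', trust_remote_code=True
).to(device).eval()
model_fp16 = AutoModelForMaskedLM.from_pretrained(
    'Synthyra/ESMplusplus_large', trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

batch = {k: v.to(device) for k, v in tokenized.items()}
with torch.no_grad():
    h32 = model_fp32(**batch).last_hidden_state
    h16 = model_fp16(**batch).last_hidden_state.float()

print(torch.mean((h32 - h16) ** 2).item())  # average MSE over the batch
```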

## Model probes
We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) performs very well.

The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/uRAHYQcwkbgajylTIFbUb.png)
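
As an illustration of the probing setup, here is a minimal linear-probe sketch on mean-pooled hidden states with hypothetical labels (the actual benchmarks use our standard datasets and metrics):

```python
import torch
from sklearn.linear_model import LogisticRegression

# Mean-pool the last hidden state over residues, masking out padding.
with torch.no_grad():
    hidden = model(**tokenized).last_hidden_state
mask = tokenized['attention_mask'].unsqueeze(-1).float()
pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)

X = pooled.float().cpu().numpy()
y = [0, 1]  # hypothetical labels, one per sequence

probe = LogisticRegression(max_iter=1000).fit(X, y)
print(probe.score(X, y))
```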

## Inference speeds
We look at various ESM models and their throughput on an H100. Going from ESMC to ESM++ with efficient batching significantly improves throughput, and ESM++ is faster than ESMC even at batch size one. ESM++ small is even faster than ESM2-35M with long sequences! The largest gains are seen with PyTorch > 2.5 on Linux machines.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/Lu6nWB9Fc-7YTql3Z1hVB.png)
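
A rough way to measure throughput on your own hardware (illustrative only, not the benchmark script):

```python
import time
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()
batch = {k: v.to(device) for k, v in tokenized.items()}

with torch.no_grad():
    # Warm up, then time a fixed number of forward passes.
    for _ in range(3):
        model(**batch)
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(20):
        model(**batch)
    if device == 'cuda':
        torch.cuda.synchronize()
elapsed = time.time() - start
print(f'{20 * batch["input_ids"].shape[0] / elapsed:.1f} sequences / second')
```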

### Citation
If you use any of this implementation or work, please cite it (as well as the ESMC preprint).

```
@misc {ESMPlusPlus,
	author       = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
	title        = { ESMPlusPlus },
	year         = 2024,
	url          = { https://huggingface.co/Synthyra/ESMplusplus_small },
	doi          = { 10.57967/hf/3726 },
	publisher    = { Hugging Face }
}
```