lhallee committed on
Commit f456813 · verified · 1 Parent(s): 9117f48

Update README.md

Files changed (1):
  1. README.md +54 -9
README.md CHANGED
@@ -3,22 +3,25 @@ library_name: transformers
- ESM++ is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) that allows for batching and standard Huggingface compatibility without requiring the ESM package.
- Use with transformers
- model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
- tokenized = tokenizer(sequences, return_tensors='pt')
- print(output.last_hidden_state) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)

@@ -28,13 +31,55 @@ ESM++ also supports sequence and token level classification tasks like ESM2. Sim
- model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
- logits = model(**tokenized)
- Measured difference between this implementation and version loaded with ESM package (1000 random sequences)
- Average MSE: 2.4649680199217982e-09

Updated README.md:

---
library_name: transformers
tags: []
---

# ESM++
ESM++ is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) ([license](https://www.evolutionaryscale.ai/policies/cambrian-open-license-agreement)) that allows for batching and standard Huggingface compatibility without requiring the ESM Python package.
The large version corresponds to the 600 million parameter version of ESMC.

## Use with 🤗 transformers
```python
from transformers import AutoModelForMaskedLM # AutoModel also works
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training

output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 1152)
print(output.loss) # language modeling loss if you passed labels
# print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
```
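The commented `labels` line above is the hook for MLM training. Below is a minimal sketch of one way to build masked inputs and `-100`-filled labels; it assumes the tokenizer exposes the standard `mask_token_id` and `pad_token_id` attributes (an assumption, not verified for this tokenizer), and a fuller recipe would also skip special tokens and use the usual 80/10/10 masking scheme.

```python
import torch

def mask_for_mlm(input_ids, mask_token_id, pad_token_id, mask_prob=0.15):
    # Labels start as a copy of the inputs; unmasked positions are set to -100
    # so the loss is computed only on masked tokens.
    labels = input_ids.clone()
    candidates = input_ids != pad_token_id  # a fuller version would also exclude special tokens
    masked = (torch.rand(input_ids.shape) < mask_prob) & candidates
    labels[~masked] = -100
    corrupted = input_ids.clone()
    corrupted[masked] = mask_token_id
    return corrupted, labels

# Build a separate batch so the original `tokenized` from above is left untouched
mlm_inputs = dict(tokenized)
mlm_inputs['input_ids'], mlm_inputs['labels'] = mask_for_mlm(
    tokenized['input_ids'], tokenizer.mask_token_id, tokenizer.pad_token_id
)
loss = model(**mlm_inputs).loss  # MLM loss over the masked positions only
```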
 
ESM++ also supports sequence and token level classification tasks like ESM2.

```python
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
```
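The import above also brings in `AutoModelForTokenClassification` for per-residue prediction. A quick sketch, assuming the token-level head accepts the same arguments as the sequence-level one; the 3-class task is a hypothetical placeholder:

```python
token_model = AutoModelForTokenClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=3, trust_remote_code=True)
token_logits = token_model(**tokenized).logits
print(token_logits.shape) # (batch_size, seq_len, num_labels), (2, 11, 3)
```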

ESM++ weights are fp32 by default. You can load them in fp16 or bf16 like this:
```python
import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, torch_dtype=torch.float16) # or torch.bfloat16
```

### Comparison across floating-point precision and implementations
We measured the difference in last hidden states between the fp32 weights and their fp16 or bf16 counterparts. The fp16 outputs are closer to the fp32 outputs, so we recommend loading in fp16.
Please note that the ESM package also loads ESMC in fp32 but casts to bf16 by default, which has its own advantages and disadvantages for inference / training - so load whichever half precision you like.

Average MSE for FP16: 0.00000003

Average MSE for BF16: 0.00000122

We also measured the difference between the outputs of ESM++ and ESMC (both in bfloat16) on 1000 random sequences to verify agreement with the ESM package.

Average MSE of last hidden state: 2.46e-09
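A sketch of how such a comparison can be reproduced; the two-sequence batch reuses the example from above, whereas the reported numbers were computed over many random sequences, and half-precision inference is assumed to run on a GPU when one is available:

```python
import torch
from transformers import AutoModelForMaskedLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_fp32 = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True).to(device).eval()
model_fp16 = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, torch_dtype=torch.float16).to(device).eval()

batch = model_fp32.tokenizer(['MPRTEIN', 'MSEQWENCE'], padding=True, return_tensors='pt').to(device)

with torch.no_grad():
    h32 = model_fp32(**batch).last_hidden_state
    h16 = model_fp16(**batch).last_hidden_state

# Cast up before the reduction so the comparison itself is done in fp32
print(torch.mean((h32 - h16.float()) ** 2).item())
```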
 
You can load the weights through the ESM package instead of transformers by replacing `.from_pretrained(...)` with `.from_pretrained_esm('esmc_600m')`.

## Model probes
We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) performs very well.

The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression tasks are averaged between Spearman rho and R2.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/88EweoycOIhJBk2RvGCKn.png)
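A minimal sketch of the general recipe (mean-pool the last hidden state, then fit a linear probe). The scikit-learn classifier, the pooling choice, and the `train_seqs` / `train_labels` / `test_seqs` / `test_labels` placeholders are illustrative assumptions, not the exact protocol behind the plot:

```python
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import matthews_corrcoef, f1_score

def embed(seqs, batch_size=8):
    # Mean-pool the last hidden state over non-padding tokens
    # (assumes the tokenizer returns an attention_mask alongside input_ids)
    feats = []
    with torch.no_grad():
        for i in range(0, len(seqs), batch_size):
            batch = tokenizer(seqs[i:i + batch_size], padding=True, return_tensors='pt')
            out = model(**batch)
            mask = batch['attention_mask'].unsqueeze(-1)
            pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1)
            feats.append(pooled)
    return torch.cat(feats).numpy()

# train_seqs / train_labels / test_seqs / test_labels are placeholders for a real labeled dataset
probe = LogisticRegression(max_iter=1000).fit(embed(train_seqs), train_labels)
preds = probe.predict(embed(test_seqs))
print(matthews_corrcoef(test_labels, preds), f1_score(test_labels, preds, average='weighted'))
```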

## Inference speeds
We look at the throughput of various ESM models on an H100. The efficient batching added in ESM++ significantly improves throughput over ESMC. ESM++ small is even faster than ESM2-35M on long sequences!
The largest gains will be seen with PyTorch > 2.5 on Linux machines.
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/NNmbupdIxuieuyjc8vMuj.png)
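A rough sketch of the kind of measurement involved (synthetic sequences and wall-clock timing on whatever device is available); the batch size, sequence length, and sample count are arbitrary choices, not the benchmark settings behind the plot:

```python
import time
import random
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_bench = model.to(device).eval()

# Synthetic protein-like sequences of a fixed length
amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
seqs = [''.join(random.choices(amino_acids, k=256)) for _ in range(64)]

batch_size = 16
start = time.perf_counter()
with torch.no_grad():
    for i in range(0, len(seqs), batch_size):
        batch = tokenizer(seqs[i:i + batch_size], padding=True, return_tensors='pt').to(device)
        model_bench(**batch)
if device == 'cuda':
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f'{len(seqs) / elapsed:.1f} sequences / second')
```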

### Citation
If you use any of this implementation or work, please cite it (as well as the ESMC preprint). A BibTeX entry for the ESMC preprint is coming soon.

```
@misc{ESMPlusPlus,
  author = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
  title = { ESMPlusPlus },
  year = 2024,
  url = { https://huggingface.co/Synthyra/ESMplusplus_small },
  doi = { 10.57967/hf/3725 },
  publisher = { Hugging Face }
}
```

### Note on the name
The original thought was ESMC++, but anything with a C would technically go against the ESM license agreement - so ESM++. Open to suggestions!