ESMplusplus_large / README.md
metadata
library_name: transformers
tags: []

ESM++ is a faithful implementation of ESMC that supports batching and works with the standard Hugging Face transformers API, without requiring the ESM package.

Use with transformers

from transformers import AutoModelForMaskedLM #AutoModel also works
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt') # pad to the longest sequence so the batch stacks into one tensor

# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training

output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
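
The commented-out `labels` line above only hints at how inputs should be masked for MLM training. The following is a minimal sketch of one way to do it, not the authors' training code. The helper name `mask_for_mlm`, the toy token ids, and the 15% masking rate are all assumptions made for illustration. For simplicity, every selected position is replaced by the mask token; the BERT-style 80/10/10 replacement scheme is not used.

```python
import torch

def mask_for_mlm(input_ids, mask_token_id, special_token_ids,
                 mlm_probability=0.15, generator=None):
    """Return (masked_input_ids, labels) for masked language modeling.

    Hypothetical helper: selected positions are replaced by the mask token,
    and labels are -100 everywhere except at the selected positions.
    """
    labels = input_ids.clone()

    # Special tokens (cls, eos, pad, ...) are never selected for masking.
    special = torch.isin(
        input_ids, torch.tensor(special_token_ids, device=input_ids.device)
    )
    probs = torch.full(input_ids.shape, mlm_probability, device=input_ids.device)
    probs.masked_fill_(special, 0.0)

    # Draw an independent Bernoulli sample per position.
    selected = torch.bernoulli(probs, generator=generator).bool()

    labels[~selected] = -100           # positions ignored by the loss
    masked = input_ids.clone()
    masked[selected] = mask_token_id   # hide the selected residues
    return masked, labels

# Toy batch with made-up ids: 0 = cls, 2 = eos, 1 = pad, 32 = mask (assumed values).
toy_ids = torch.tensor([[0, 5, 6, 7, 2, 1],
                        [0, 8, 9, 10, 11, 2]])
masked_ids, labels = mask_for_mlm(toy_ids, mask_token_id=32,
                                  special_token_ids=[0, 1, 2])
print(masked_ids.shape, labels.shape)
```

With a real tokenizer, the mask and special token ids would typically come from `tokenizer.mask_token_id` and `tokenizer.all_special_ids`, assuming the ESM++ tokenizer exposes the usual Hugging Face tokenizer attributes. The returned tensors would then be assigned to `tokenized['input_ids']` and `tokenized['labels']` before calling the model.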

ESM++ also supports sequence and token level classification tasks like ESM2. Simply pass the number of labels during initialization.

from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits # the model returns an output object; the logits are an attribute of it
print(logits.shape) # (batch_size, num_labels), (2, 2)

Measured difference between this implementation and the version loaded with the ESM package (1000 random sequences)

Average MSE: 2.4649680199217982e-09
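
The README does not say which tensors were compared or how the average was taken. As an illustration only, the sketch below assumes the comparison is between per-sequence hidden states from the two implementations, with the MSE computed per sequence and then averaged across sequences. The function name `average_mse` and the synthetic tensors are hypothetical stand-ins for real model outputs.

```python
import torch

def average_mse(reference, candidate):
    """Average, over sequences, of the per-sequence mean squared error.

    `reference` and `candidate` are equal-length lists of tensors with
    matching shapes, e.g. one (seq_len, hidden_size) tensor per sequence.
    """
    assert len(reference) == len(candidate) and len(reference) > 0
    per_sequence = []
    for ref, cand in zip(reference, candidate):
        # Compare in float32 so half-precision outputs do not lose accuracy.
        diff = ref.float() - cand.float()
        per_sequence.append(torch.mean(diff ** 2).item())
    return sum(per_sequence) / len(per_sequence)

# Synthetic stand-ins for hidden states of two sequences of different length.
torch.manual_seed(0)
ref_states = [torch.randn(9, 960), torch.randn(11, 960)]
# Add small noise to imitate numerical differences between two implementations.
new_states = [t + 1e-5 * torch.randn_like(t) for t in ref_states]
print(average_mse(ref_states, new_states))
```

Averaging per sequence first keeps long sequences from dominating the result. Pooling all elements into a single MSE is an equally plausible reading of the reported number.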