basilboy committed on
Commit
121b388
·
verified ·
1 Parent(s): 7f07fc3

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +5 -1
utils.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import AutoTokenizer
2
  import torch
 
3
 
4
  def validate_sequence(sequence):
5
  valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
@@ -15,4 +16,7 @@ def predict(model, sequence):
15
  tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
16
  tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
17
  output = model(**tokenized_input)
18
- return output.item()
 
 
 
 
1
  from transformers import AutoTokenizer
2
  import torch
3
+ import torch.nn.functional as F
4
 
5
  def validate_sequence(sequence):
6
  valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
 
16
  tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
17
  tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
18
  output = model(**tokenized_input)
19
+ logits = output.logits # Extract logits
20
+ probabilities = F.softmax(logits, dim=-1) # Apply softmax to convert logits to probabilities
21
+ predicted_label = torch.argmax(probabilities, dim=-1) # Get the predicted label
22
+ return predicted_label.item() # Return the label as a Python integer