Update utils.py
utils.py CHANGED
@@ -1,5 +1,6 @@
 from transformers import AutoTokenizer
 import torch
+import torch.nn.functional as F
 
 def validate_sequence(sequence):
     valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")  # 20 standard amino acids
@@ -15,4 +16,7 @@ def predict(model, sequence):
     tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
     tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
     output = model(**tokenized_input)
-
+    logits = output.logits  # Extract logits
+    probabilities = F.softmax(logits, dim=-1)  # Apply softmax to convert logits to probabilities
+    predicted_label = torch.argmax(probabilities, dim=-1)  # Get the predicted label
+    return predicted_label.item()  # Return the label as a Python integer
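
For reference, a minimal usage sketch of the updated predict helper. It assumes validate_sequence returns a boolean and that the model is a sequence-classification head (for example, one loaded via AutoModelForSequenceClassification); the model-loading code is not part of this commit, and the base ESM-2 checkpoint ships without a fine-tuned classification head, so in practice a fine-tuned checkpoint would be used here.

from transformers import AutoModelForSequenceClassification

# Assumption: a classification model compatible with the ESM-2 tokenizer.
# The checkpoint name below is illustrative; a fine-tuned model would replace it.
model = AutoModelForSequenceClassification.from_pretrained('facebook/esm2_t6_8M_UR50D')

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # example protein sequence
if validate_sequence(sequence):       # helper defined earlier in utils.py
    label = predict(model, sequence)  # returns the predicted class index as an int
    print(f"Predicted label: {label}")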