from transformers import AutoTokenizer
import torch
import torch.nn.functional as F


def validate_sequence(sequence):
    # Accept only sequences built from the 20 standard amino acids, capped at 200 residues
    valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")  # 20 standard amino acids
    return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200


def load_model(model_name):
    # Load the serialized model for the given name and switch it to evaluation mode
    model = torch.load(f'{model_name}_model.pth', map_location=torch.device('cpu'))
    model.eval()
    return model


def predict(model, sequence):
    # Tokenize the sequence with the ESM-2 tokenizer and run the classifier without tracking gradients
    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**tokenized_input)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    confidence = probabilities.max().item()  # report the raw maximum class probability as the confidence
    return predicted_label.item(), confidence
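

# Minimal usage sketch. Assumptions not part of the original module: a fine-tuned
# classifier has been saved as 'esm2_model.pth' (so load_model('esm2') finds it),
# and the example sequence below is purely illustrative.
if __name__ == "__main__":
    example_sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical protein sequence
    if validate_sequence(example_sequence):
        classifier = loqd_model = load_model('esm2')  # hypothetical checkpoint name
        label, confidence = predict(classifier, example_sequence)
        print(f"Predicted label: {label} (confidence: {confidence:.3f})")
    else:
        print("Invalid sequence: non-standard residues or longer than 200 amino acids.")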