|
import functools

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
|
|
|
def validate_sequence(sequence, max_length=200):
    """Return True if *sequence* is a valid protein sequence.

    A sequence is valid when it is non-empty, at most *max_length* residues
    long, and made up exclusively of the 20 standard uppercase one-letter
    amino-acid codes.

    Args:
        sequence: Protein sequence string to check.
        max_length: Maximum accepted length (default 200, matching the
            original hard-coded cap).

    Returns:
        bool: True if the sequence passes all checks.
    """
    # Bug fix: the original returned True for "" because all() over an
    # empty iterable is vacuously True — an empty sequence is not valid.
    if not sequence:
        return False

    valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
    return len(sequence) <= max_length and all(aa in valid_amino_acids for aa in sequence)
|
|
|
def load_model(model_name):
    """Load a pickled model from ``{model_name}_model.pth`` onto the CPU.

    Args:
        model_name: Prefix of the checkpoint file on disk.

    Returns:
        The deserialized model, switched to evaluation mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # weights_only=False is required to unpickle a full nn.Module object and
    # preserves the original behavior on PyTorch >= 2.6, where the default
    # flipped to True (which would make this call fail).
    # SECURITY NOTE(review): torch.load with weights_only=False executes
    # arbitrary pickle payloads — only load checkpoints from trusted sources.
    model = torch.load(
        f'{model_name}_model.pth',
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    model.eval()  # disable dropout/batch-norm training behavior for inference
    return model
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Build the ESM-2 tokenizer once and reuse it across predict() calls."""
    return AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')


def predict(model, sequence):
    """Run the classifier on one protein sequence.

    Args:
        model: A sequence-classification model whose output exposes
            ``.logits`` (presumably a HuggingFace ``*ForSequenceClassification``
            — TODO confirm with load_model's checkpoint).
        sequence: Protein sequence string to classify.

    Returns:
        tuple[int, float]: The argmax class index and a scaled confidence
        score derived from the softmax maximum.
    """
    # Perf fix: the original rebuilt the tokenizer (hub fetch + vocab load)
    # on every call; the cached helper constructs it exactly once.
    tokenizer = _get_tokenizer()
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)

    # Inference only — disable autograd so no computation graph is built.
    with torch.no_grad():
        output = model(**tokenized_input)

    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)

    # NOTE(review): the 0.85 factor is an undocumented fudge that caps the
    # reported "confidence" at 0.85; kept as-is for caller compatibility —
    # confirm whether this calibration is intentional.
    confidence = probabilities.max().item() * 0.85

    return predicted_label.item(), confidence