File size: 2,336 Bytes
23fcfc1
21849ba
49c5855
121b388
6116303
 
49c5855
 
 
 
 
76b0555
 
 
49c5855
 
 
76b0555
49c5855
21849ba
 
 
76b0555
 
 
a6bbca3
 
12565b9
a6bbca3
12565b9
 
 
a6bbca3
12565b9
a6bbca3
 
 
12565b9
 
 
 
 
 
 
a6bbca3
12565b9
a6bbca3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import streamlit as st
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# The 20 standard amino acids, as one-letter upper-case codes.
_STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")


def validate_sequence(sequence, max_length=200):
    """Return True if *sequence* is a non-empty, standard-residue protein string.

    Args:
        sequence: Candidate sequence of one-letter amino-acid codes. Matching
            is case-sensitive: only upper-case codes are accepted.
        max_length: Longest sequence accepted (inclusive). Defaults to 200.

    Returns:
        True when the sequence is non-empty, no longer than ``max_length`` and
        made up solely of the 20 standard amino-acid codes; False otherwise.
    """
    # An empty (or missing) sequence is not a usable input, and the length
    # test is cheaper than scanning residues, so reject on both up front.
    if not sequence or len(sequence) > max_length:
        return False
    return _STANDARD_AMINO_ACIDS.issuperset(sequence)

def load_model(model_name):
    """Load ``<model_name>_model.pth`` onto the CPU and switch it to eval mode.

    Args:
        model_name: Prefix of the checkpoint file; the file read is
            ``f'{model_name}_model.pth'``, resolved relative to the current
            working directory.

    Returns:
        The object returned by ``torch.load``, after ``.eval()`` was called
        on it.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    the PyTorch version and its ``weights_only`` default — only point this at
    trusted checkpoint files.
    """
    checkpoint_path = f'{model_name}_model.pth'
    cpu_device = torch.device('cpu')
    # map_location remaps any stored tensors to the CPU while loading.
    net = torch.load(checkpoint_path, map_location=cpu_device)
    net.eval()
    return net


# Checkpoint name handed to AutoTokenizer.from_pretrained.
_TOKENIZER_NAME = 'facebook/esm2_t6_8M_UR50D'

# Tokenizers already constructed in this process, keyed by checkpoint name.
_TOKENIZER_CACHE = {}

# NOTE(review): the reported confidence is the top probability multiplied by
# this factor; the rationale is not visible in this file — confirm it is
# intentional before relying on the value as a calibrated probability.
_CONFIDENCE_SCALE = 0.85


def _get_tokenizer(name=_TOKENIZER_NAME):
    """Return the tokenizer for *name*, constructing it only on first use."""
    tokenizer = _TOKENIZER_CACHE.get(name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(name)
        _TOKENIZER_CACHE[name] = tokenizer
    return tokenizer


def predict(model, sequence):
    """Classify a single sequence and report the predicted label and confidence.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``logits`` tensor.
        sequence: The sequence string to tokenize and classify. A single
            sequence is expected, since the label is extracted with
            ``.item()``.

    Returns:
        A ``(predicted_label, confidence)`` tuple: the arg-max class index as
        an int, and the maximum softmax probability scaled by
        ``_CONFIDENCE_SCALE`` as a float.
    """
    # Reuse the tokenizer instead of rebuilding it on every call.
    tokenizer = _get_tokenizer()
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)

    # Inference only: skip autograd bookkeeping so no graph is retained.
    with torch.no_grad():
        output = model(**tokenized_input)
        probabilities = F.softmax(output.logits, dim=-1)

    predicted_label = torch.argmax(probabilities, dim=-1)
    confidence = probabilities.max().item() * _CONFIDENCE_SCALE
    return predicted_label.item(), confidence

def plot_prediction_graphs(data,model_keys):
    """Render, for each model, two bar charts of confidence split by prediction.

    For every key in ``model_keys`` a figure with two panels is drawn: the
    left panel holds the entries whose prediction is 0, the right panel those
    whose prediction is 1. Bars are ordered by descending confidence and each
    name keeps the same colour in every figure. Each figure is sent to
    Streamlit and then closed.

    Args:
        data: Mapping of name -> mapping of model key -> pair whose element
            ``[0]`` is the prediction (compared against 0 and 1) and whose
            element ``[1]`` is the confidence plotted on the y-axis.
        model_keys: Iterable of model keys to plot; each must be present in
            every entry of ``data``.
    """
    # One colour per name, fixed up front so it is consistent across graphs.
    unique_names = sorted(data)
    palette = sns.color_palette("hsv", len(unique_names))
    color_dict = dict(zip(unique_names, palette))

    for model_name in model_keys:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        try:
            for prediction_val, ax in zip((0, 1), axes):
                matching = [
                    (name, values[model_name])
                    for name, values in data.items()
                    if values[model_name][0] == prediction_val
                ]
                # Highest confidence first.
                matching.sort(key=lambda item: item[1][1], reverse=True)
                names = [name for name, _ in matching]
                conf_values = [result[1] for _, result in matching]
                colors = [color_dict[name] for name in names]

                # No entry may have this prediction; leave the panel empty
                # rather than handing seaborn empty inputs.
                if names:
                    sns.barplot(x=names, y=conf_values, palette=colors, ax=ax)
                ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
                ax.set_xlabel('Names')
                ax.set_ylabel('Confidence')
                ax.tick_params(axis='x', rotation=45)  # Rotate x labels for better visibility

            st.pyplot(fig)
        finally:
            # Figures made through pyplot stay registered until closed;
            # release each one so repeated calls do not accumulate them.
            plt.close(fig)