basilboy commited on
Commit
186e4b6
·
verified ·
1 Parent(s): da0fe61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -24
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from utils import validate_sequence, predict
3
  from model import models
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
@@ -56,29 +56,6 @@ def main():
56
  st.write("## Graphs")
57
  plot_prediction_graphs(all_data)
58
 
59
- def plot_prediction_graphs(data):
60
- # Create a color palette that is consistent across graphs
61
- unique_sequences = sorted(set(seq for seq in data))
62
- palette = sns.color_palette("hsv", len(unique_sequences))
63
- color_dict = {seq: color for seq, color in zip(unique_sequences, palette)}
64
-
65
- for model_name in models.keys():
66
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
67
- for prediction_val in [0, 1]:
68
- ax = ax1 if prediction_val == 0 else ax2
69
- filtered_data = {seq: values[model_name] for seq, values in data.items() if values[model_name][0] == prediction_val}
70
- # Sorting sequences based on confidence, descending
71
- sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
72
- sequences = [x[0] for x in sorted_sequences]
73
- conf_values = [x[1][1] for x in sorted_sequences]
74
- colors = [color_dict[seq] for seq in sequences]
75
- sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
76
- ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
77
- ax.set_xlabel('Sequences')
78
- ax.set_ylabel('Confidence')
79
- ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
80
-
81
- st.pyplot(fig) # Display the plot with two subplots below the results table
82
 
83
  if __name__ == "__main__":
84
  main()
 
1
  import streamlit as st
2
+ from utils import validate_sequence, predict, plot_prediction_graphs
3
  from model import models
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
 
56
  st.write("## Graphs")
57
  plot_prediction_graphs(all_data)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  if __name__ == "__main__":
61
  main()