test3 / app.py
basilboy's picture
Update app.py
44d5af7 verified
import streamlit as st
from utils import validate_sequence, predict, plot_prediction_graphs
from model import models
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def _collect_sequences(sequence, uploaded_file):
    """Build parallel lists of display names and sequences from the sidebar inputs.

    The typed sequence (if any) comes first and is named "Seq 1". Rows from the
    uploaded CSV follow. They are named from the CSV's ``name`` column when it
    exists; otherwise they continue the "Seq N" numbering.

    Args:
        sequence: Text typed into the sidebar box (empty string when unused).
        uploaded_file: Object returned by ``st.sidebar.file_uploader``, or ``None``.

    Returns:
        ``(names, sequences)``: two lists of equal length, so zipping them
        never misaligns or drops an entry.
    """
    names = []
    sequences = []
    if sequence:
        names.append("Seq 1")
        sequences.append(sequence)

    if uploaded_file is None:
        return names, sequences

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.sidebar.error(f"Could not read the uploaded CSV file: {err}")
        return names, sequences

    if "sequence" not in df.columns:
        st.sidebar.error("The uploaded CSV must contain a 'sequence' column.")
        return names, sequences

    # Blank cells are read as NaN (a float); skip those rows rather than
    # handing a non-string to validate_sequence.
    df = df.dropna(subset=["sequence"])
    csv_sequences = df["sequence"].astype(str).tolist()
    if "name" in df.columns:
        csv_names = df["name"].astype(str).tolist()
    else:
        # Continue numbering after the typed sequence so labels stay unique.
        start = len(sequences) + 1
        csv_names = [f"Seq {start + i}" for i in range(len(csv_sequences))]

    names.extend(csv_names)
    sequences.extend(csv_sequences)
    return names, sequences


def _analyze(names, sequences):
    """Run every model in ``models`` on every valid sequence.

    Invalid sequences (per ``validate_sequence``) are reported in the sidebar
    and skipped.

    Returns:
        ``(results, all_data)``. ``results`` is a list of table rows, one dict
        per valid sequence, holding the name, the sequence and each model's
        prediction and rounded confidence. ``all_data`` maps each name to
        ``{model_name: (prediction, confidence)}`` for plotting.
    """
    results = []
    all_data = {}
    for name, seq in zip(names, sequences):
        if not validate_sequence(seq):
            st.sidebar.error(f"Invalid sequence for {name}: {seq}")
            continue
        model_results = {}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            model_results[f"{model_name}_prediction"] = prediction
            model_results[f"{model_name}_confidence"] = round(confidence, 3)
            graph_data[model_name] = (prediction, confidence)
        results.append({"Name": name, "Sequence": seq, **model_results})
        # Keyed by name; a repeated name overwrites the earlier graph entry.
        all_data[name] = graph_data
    return results, all_data


def main():
    """Render the app: sidebar inputs, a results table and optional prediction graphs."""
    st.set_page_config(layout="wide")  # Keep the wide layout for overall flexibility
    st.title("AA Property Inference Demo", anchor=None)
    # Instructional text below title
    st.markdown("""
<style>
.reportview-container {
font-family: 'Courier New', monospace;
}
</style>
<p style='font-size:16px;'><span style='font-size:24px;'>&larr;</span> Don't know where to start? Open tab to input a sequence.</p>
""", unsafe_allow_html=True)

    # Input section in the sidebar
    sequence = st.sidebar.text_input("Enter your amino acid sequence:").strip()
    uploaded_file = st.sidebar.file_uploader("Or upload a CSV file with amino acid sequences", type="csv")
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    if analyze_pressed:
        names, sequences = _collect_sequences(sequence, uploaded_file)
        if not sequences:
            st.sidebar.warning("Enter a sequence or upload a CSV file first.")
        results, all_data = _analyze(names, sequences)
        # st.button is True only on the rerun immediately after the click.
        # Persist the output so that toggling the graphs checkbox (which
        # triggers another rerun) does not wipe the results.
        st.session_state["analysis_results"] = results
        st.session_state["analysis_graph_data"] = all_data

    results = st.session_state.get("analysis_results", [])
    all_data = st.session_state.get("analysis_graph_data", {})

    if results:
        results_df = pd.DataFrame(results)
        st.write("### Results")
        st.dataframe(results_df.style.format(precision=3), width=None, height=None)
    if show_graphs and all_data:
        st.write("## Graphs")
        plot_prediction_graphs(all_data, models.keys())
if __name__ == "__main__":
    # Build the app only when this file is executed as a script, not on import.
    main()