"""Streamlit page: upload a labelled CSV and train a few-shot SetFit model."""

import pickle
from pathlib import Path

import pandas as pd
import streamlit as st
from datasets import Dataset

import model
from utils import check_columns, count_labels, get_download_link

SAMPLE_CSV_PATH = "data/sample.csv"
PRETRAINED_MODEL_OPTIONS = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]


def _encode_labels(df):
    """Replace ``df["label"]`` values with integer ids, in place.

    Ids are assigned in order of first appearance. This keeps the mapping
    deterministic; iterating a ``set`` of strings is not, because string
    hashing is randomized per interpreter process. ``Series.unique()`` also
    collapses repeated NaN values into one entry, which a ``set`` built from
    ``tolist()`` does not.

    Returns:
        dict mapping each original label to its integer id.
    """
    labels = df["label"].unique().tolist()
    label_map = {label: idx for idx, label in enumerate(labels)}
    df["label"] = df["label"].map(label_map)
    return label_map


# Main function to run the Streamlit app
def main():
    """Render the training page.

    Flow: offer a sample CSV for download, accept an uploaded CSV, validate
    its columns, encode the labels as integers, and, when "Train" is
    pressed, call ``model.run_setfit_training``. The inverse label mapping
    (id -> original label) is pickled next to the returned checkpoint path
    as ``label.pkl``.
    """
    st.title("Few Shot Learning Demo using SetFit")

    st.write("Prepare CSV file with text and label header, here is the sample file")
    sample_df = pd.read_csv(SAMPLE_CSV_PATH)
    st.markdown(get_download_link(sample_df), unsafe_allow_html=True)

    # The session id is expected to be set elsewhere in the app; fail with a
    # readable message instead of an AttributeError if it is missing.
    if "key" not in st.session_state:
        st.error("No session ID found (st.session_state.key is not set).")
        st.stop()
    session_id = st.session_state.key

    uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded CSV: {err}")
        return

    if not check_columns(df):
        st.error("File must have 'text' and 'label' columns.")
        return

    st.write(df)

    label_counts = count_labels(df)
    st.write(f"Number of instances of each label class: {label_counts}")

    label_map = _encode_labels(df)
    dataset = Dataset.from_pandas(df)

    model_name = st.text_input("Input the model name")
    pretrained_model = st.selectbox(
        "Select a pretrained model", options=PRETRAINED_MODEL_OPTIONS
    )

    if not st.button("Train"):
        return

    # NOTE(review): assumes the model name is required (it is passed to the
    # trainer and likely names the checkpoint); confirm with model.py.
    if not model_name.strip():
        st.error("Please input a model name before training.")
        return

    with st.spinner("Training model..."):
        # The trailing positional args (1, 10) are forwarded unchanged; their
        # meaning is defined by model.run_setfit_training. TODO: pass them by
        # keyword once that signature is confirmed.
        model_path = model.run_setfit_training(
            session_id,
            pretrained_model,
            model_name,
            dataset,
            1,
            10,
        )
    checkpoint_dir = Path(model_path)
    st.write(f"Model checkpoint saved {checkpoint_dir.name}")

    # Persist id -> original label so predictions can be decoded later.
    id_to_label = {idx: label for label, idx in label_map.items()}
    with open(checkpoint_dir / "label.pkl", "wb") as f:
        pickle.dump(id_to_label, f)

    st.write("Training Finished")
    st.write("Go to Validation Page")