# Spaces: Sleeping
import os
import pickle

import pandas as pd
import streamlit as st
from datasets import Dataset

import model
from utils import check_columns, count_labels, get_download_link
def _encode_labels(df):
    """Replace the ``label`` column of *df* with integer class ids, in place.

    Ids are assigned in order of first appearance. That keeps the mapping
    reproducible across runs; iterating a ``set`` of strings is not, because
    string hashing is randomised per process.

    Returns:
        dict mapping each original label value to its integer id.
    """
    # dict.fromkeys de-duplicates while preserving first-appearance order.
    unique_labels = list(dict.fromkeys(df["label"].tolist()))
    label_map = {label: idx for idx, label in enumerate(unique_labels)}
    df["label"] = df["label"].map(label_map)
    return label_map


def main():
    """Run the Streamlit few-shot training page.

    The page offers the bundled sample CSV for download and accepts an
    uploaded CSV that must have ``text`` and ``label`` columns. Labels are
    encoded as integer ids. When "Train" is clicked, the page calls
    ``model.run_setfit_training`` and writes the id-to-label mapping to
    ``label.pkl`` inside the returned checkpoint directory.
    """
    st.title("Few Shot Learning Demo using SetFit")
    st.write("Prepare CSV file with text and label header, here is the sample file")
    sample_df = pd.read_csv("data/sample.csv")
    # Display a link to download the sample file.
    st.markdown(get_download_link(sample_df), unsafe_allow_html=True)

    # NOTE(review): assumes ``key`` was stored in session state elsewhere
    # (e.g. by an entry page); attribute access raises if it is missing.
    session_id = st.session_state.key

    uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Report a bad upload on the page instead of crashing the script run.
        st.error(f"Could not read the uploaded CSV: {err}")
        return

    if not check_columns(df):
        st.error("File must have 'text' and 'label' columns.")
        return

    st.write(df)
    # Count label classes before the labels are replaced by integer ids.
    label_counts = count_labels(df)
    st.write(f"Number of instances of each label class: {label_counts}")

    label_map = _encode_labels(df)
    dataset = Dataset.from_pandas(df)

    model_name = st.text_input("Input the model name")
    pretrained_model_options = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]
    pretrained_model = st.selectbox(
        "Select a pretrained model", options=pretrained_model_options
    )

    if not st.button("Train"):
        return
    if not model_name.strip():
        st.error("Please input a model name before training.")
        return

    with st.spinner("Training model..."):
        # NOTE(review): the positional 1 and 10 are forwarded unchanged to
        # run_setfit_training; their meaning is defined there — confirm.
        model_path = model.run_setfit_training(
            session_id,
            pretrained_model,
            model_name,
            dataset,
            1,
            10,
        )
        st.write(f"Model checkpoint saved {os.path.basename(model_path)}")
        # Store the inverse mapping (id -> original label) next to the
        # checkpoint; presumably the validation page reads it back to decode
        # predicted ids.
        id2label = {idx: label for label, idx in label_map.items()}
        with open(os.path.join(model_path, "label.pkl"), "wb") as f:
            pickle.dump(id2label, f)
        st.write("Training Finished")
        st.write("Go to Validation Page")