# few-shot-demo / training.py
# spdin
# add csv file
# f0ad92c
import os
import pickle

import pandas as pd
import streamlit as st
from datasets import Dataset

import model
from utils import check_columns, count_labels, get_download_link
# Main function to run the Streamlit app
def main():
    """Render the SetFit few-shot training page.

    Streamlit re-runs this function top to bottom on every interaction:

    1. Offer ``data/sample.csv`` for download as a template.
    2. Accept an uploaded CSV and reject it unless ``check_columns`` accepts it.
    3. Show the data and per-label counts, then encode labels as integer ids.
    4. When "Train" is clicked, call ``model.run_setfit_training`` and pickle
       the id -> label mapping next to the returned checkpoint as ``label.pkl``.
    """
    st.title("Few Shot Learning Demo using SetFit")
    st.write("Prepare CSV file with text and label header, here is the sample file")
    sample_df = pd.read_csv("data/sample.csv")
    st.markdown(get_download_link(sample_df), unsafe_allow_html=True)

    # NOTE(review): assumes "key" was stored in session_state elsewhere (e.g. by
    # the app entry point); attribute access raises if it is missing.
    session_id = st.session_state.key

    uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")
    if uploaded_file is None:
        return

    df = pd.read_csv(uploaded_file)
    if not check_columns(df):
        st.error("File must have 'text' and 'label' columns.")
        return

    st.write(df)
    label_counts = count_labels(df)
    st.write(f"Number of instances of each label class: {label_counts}")

    df, label_map = _encode_labels(df)
    dataset = Dataset.from_pandas(df)

    model_name = st.text_input("Input the model name")
    pretrained_model_options = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]
    pretrained_model = st.selectbox(
        "Select a pretrained model", options=pretrained_model_options
    )

    if st.button("Train"):
        with st.spinner("Training model..."):
            # The trailing 1 and 10 are passed positionally; their meaning is
            # defined by model.run_setfit_training (not visible here) — confirm.
            model_path = model.run_setfit_training(
                session_id,
                pretrained_model,
                model_name,
                dataset,
                1,
                10,
            )
            # basename/join handle str or PathLike and OS-specific separators.
            st.write(f"Model checkpoint saved {os.path.basename(model_path)}")
            # Store the inverse mapping (id -> original label) with the
            # checkpoint so predictions can be turned back into label names.
            id2label = {idx: label for label, idx in label_map.items()}
            with open(os.path.join(model_path, "label.pkl"), "wb") as f:
                pickle.dump(id2label, f)
        st.write("Training Finished")
        st.write("Go to Validation Page")


def _encode_labels(df: pd.DataFrame):
    """Return ``(encoded_df, label_map)`` with labels replaced by integer ids.

    ``label_map`` maps each original label to its id. Ids are assigned in
    sorted label order, so the same CSV always yields the same mapping;
    enumerating a ``set`` of strings would depend on the per-process hash seed.
    The input DataFrame is not modified.
    """
    unique_labels = set(df["label"].tolist())
    try:
        ordered_labels = sorted(unique_labels)
    except TypeError:
        # Mixed, mutually unorderable label types: order deterministically by
        # type name, then string form.
        ordered_labels = sorted(
            unique_labels, key=lambda lab: (type(lab).__name__, str(lab))
        )
    label_map = {label: idx for idx, label in enumerate(ordered_labels)}
    encoded = df.copy()
    encoded["label"] = encoded["label"].map(label_map)
    return encoded, label_map