Spaces:
Running
Running
Commit
·
968f4bc
1
Parent(s):
0cde3e9
add: integration of training with app
Browse files- app.py +8 -1
- application_pages/train_classifier.py +57 -0
- guardrails_genie/train_classifier.py +31 -2
app.py
CHANGED
@@ -13,6 +13,13 @@ evaluation_page = st.Page(
|
|
13 |
title="Evaluation",
|
14 |
icon=":material/monitoring:",
|
15 |
)
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
|
18 |
page_navigation.run()
|
|
|
13 |
title="Evaluation",
|
14 |
icon=":material/monitoring:",
|
15 |
)
|
16 |
+
train_classifier_page = st.Page(
|
17 |
+
"application_pages/train_classifier.py",
|
18 |
+
title="Train Classifier",
|
19 |
+
icon=":material/fitness_center:",
|
20 |
+
)
|
21 |
+
page_navigation = st.navigation(
|
22 |
+
[intro_page, chat_page, evaluation_page, train_classifier_page]
|
23 |
+
)
|
24 |
st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
|
25 |
page_navigation.run()
|
application_pages/train_classifier.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
|
4 |
+
import wandb
|
5 |
+
from guardrails_genie.train_classifier import train_binary_classifier
|
6 |
+
|
7 |
+
|
8 |
+
def initialize_session_state():
|
9 |
+
load_dotenv()
|
10 |
+
if "dataset_name" not in st.session_state:
|
11 |
+
st.session_state.dataset_name = None
|
12 |
+
if "base_model_name" not in st.session_state:
|
13 |
+
st.session_state.base_model_name = None
|
14 |
+
if "batch_size" not in st.session_state:
|
15 |
+
st.session_state.batch_size = 16
|
16 |
+
if "should_start_training" not in st.session_state:
|
17 |
+
st.session_state.should_start_training = False
|
18 |
+
if "training_output" not in st.session_state:
|
19 |
+
st.session_state.training_output = None
|
20 |
+
|
21 |
+
|
22 |
+
initialize_session_state()
|
23 |
+
st.title(":material/fitness_center: Train Classifier")
|
24 |
+
|
25 |
+
dataset_name = st.sidebar.text_input("Dataset Name", value="")
|
26 |
+
st.session_state.dataset_name = dataset_name
|
27 |
+
|
28 |
+
base_model_name = st.sidebar.selectbox(
|
29 |
+
"Base Model", options=["distilbert/distilbert-base-uncased", "roberta-base"]
|
30 |
+
)
|
31 |
+
st.session_state.base_model_name = base_model_name
|
32 |
+
|
33 |
+
batch_size = st.sidebar.slider(
|
34 |
+
"Batch Size", min_value=4, max_value=256, value=16, step=4
|
35 |
+
)
|
36 |
+
st.session_state.batch_size = batch_size
|
37 |
+
|
38 |
+
train_button = st.sidebar.button("Train")
|
39 |
+
st.session_state.should_start_training = (
|
40 |
+
train_button and st.session_state.dataset_name and st.session_state.base_model_name
|
41 |
+
)
|
42 |
+
|
43 |
+
if st.session_state.should_start_training:
|
44 |
+
with st.expander("Training", expanded=True):
|
45 |
+
st.markdown(
|
46 |
+
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
47 |
+
)
|
48 |
+
training_output = train_binary_classifier(
|
49 |
+
project_name="guardrails-genie",
|
50 |
+
entity_name="geekyrakshit",
|
51 |
+
dataset_repo=st.session_state.dataset_name,
|
52 |
+
model_name=st.session_state.base_model_name,
|
53 |
+
batch_size=st.session_state.batch_size,
|
54 |
+
streamlit_mode=True,
|
55 |
+
)
|
56 |
+
st.session_state.training_output = training_output
|
57 |
+
st.write(training_output)
|
guardrails_genie/train_classifier.py
CHANGED
@@ -1,14 +1,39 @@
|
|
|
|
1 |
import evaluate
|
2 |
import numpy as np
|
3 |
-
import
|
4 |
from datasets import load_dataset
|
5 |
from transformers import (
|
6 |
AutoModelForSequenceClassification,
|
7 |
AutoTokenizer,
|
8 |
DataCollatorWithPadding,
|
9 |
Trainer,
|
|
|
10 |
TrainingArguments,
|
11 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
def train_binary_classifier(
|
@@ -20,6 +45,7 @@ def train_binary_classifier(
|
|
20 |
batch_size: int = 16,
|
21 |
num_epochs: int = 2,
|
22 |
weight_decay: float = 0.01,
|
|
|
23 |
):
|
24 |
wandb.init(project=project_name, entity=entity_name)
|
25 |
dataset = load_dataset(dataset_repo)
|
@@ -69,5 +95,8 @@ def train_binary_classifier(
|
|
69 |
processing_class=tokenizer,
|
70 |
data_collator=data_collator,
|
71 |
compute_metrics=compute_metrics,
|
|
|
72 |
)
|
73 |
-
trainer.train()
|
|
|
|
|
|
1 |
+
|
2 |
import evaluate
|
3 |
import numpy as np
|
4 |
+
import streamlit as st
|
5 |
from datasets import load_dataset
|
6 |
from transformers import (
|
7 |
AutoModelForSequenceClassification,
|
8 |
AutoTokenizer,
|
9 |
DataCollatorWithPadding,
|
10 |
Trainer,
|
11 |
+
TrainerCallback,
|
12 |
TrainingArguments,
|
13 |
)
|
14 |
+
from transformers.trainer_callback import TrainerControl, TrainerState
|
15 |
+
|
16 |
+
import wandb
|
17 |
+
|
18 |
+
|
19 |
+
class StreamlitProgressbarCallback(TrainerCallback):
|
20 |
+
|
21 |
+
def __init__(self, *args, **kwargs):
|
22 |
+
super().__init__(*args, **kwargs)
|
23 |
+
self.progress_bar = st.progress(0, text="Training")
|
24 |
+
|
25 |
+
def on_step_begin(
|
26 |
+
self,
|
27 |
+
args: TrainingArguments,
|
28 |
+
state: TrainerState,
|
29 |
+
control: TrainerControl,
|
30 |
+
**kwargs,
|
31 |
+
):
|
32 |
+
super().on_step_begin(args, state, control, **kwargs)
|
33 |
+
self.progress_bar.progress(
|
34 |
+
(state.global_step * 100 // state.max_steps) + 1,
|
35 |
+
text=f"Training {state.global_step} / {state.max_steps}",
|
36 |
+
)
|
37 |
|
38 |
|
39 |
def train_binary_classifier(
|
|
|
45 |
batch_size: int = 16,
|
46 |
num_epochs: int = 2,
|
47 |
weight_decay: float = 0.01,
|
48 |
+
streamlit_mode: bool = False,
|
49 |
):
|
50 |
wandb.init(project=project_name, entity=entity_name)
|
51 |
dataset = load_dataset(dataset_repo)
|
|
|
95 |
processing_class=tokenizer,
|
96 |
data_collator=data_collator,
|
97 |
compute_metrics=compute_metrics,
|
98 |
+
callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
|
99 |
)
|
100 |
+
training_output = trainer.train()
|
101 |
+
wandb.finish()
|
102 |
+
return training_output
|