geekyrakshit commited on
Commit
968f4bc
·
1 Parent(s): 0cde3e9

add: integration of training with app

Browse files
app.py CHANGED
@@ -13,6 +13,13 @@ evaluation_page = st.Page(
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
16
- page_navigation = st.navigation([intro_page, chat_page, evaluation_page])
 
 
 
 
 
 
 
17
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
18
  page_navigation.run()
 
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
16
+ train_classifier_page = st.Page(
17
+ "application_pages/train_classifier.py",
18
+ title="Train Classifier",
19
+ icon=":material/fitness_center:",
20
+ )
21
+ page_navigation = st.navigation(
22
+ [intro_page, chat_page, evaluation_page, train_classifier_page]
23
+ )
24
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
25
  page_navigation.run()
application_pages/train_classifier.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from dotenv import load_dotenv
3
+
4
+ import wandb
5
+ from guardrails_genie.train_classifier import train_binary_classifier
6
+
7
+
8
+ def initialize_session_state():
9
+ load_dotenv()
10
+ if "dataset_name" not in st.session_state:
11
+ st.session_state.dataset_name = None
12
+ if "base_model_name" not in st.session_state:
13
+ st.session_state.base_model_name = None
14
+ if "batch_size" not in st.session_state:
15
+ st.session_state.batch_size = 16
16
+ if "should_start_training" not in st.session_state:
17
+ st.session_state.should_start_training = False
18
+ if "training_output" not in st.session_state:
19
+ st.session_state.training_output = None
20
+
21
+
22
+ initialize_session_state()
23
+ st.title(":material/fitness_center: Train Classifier")
24
+
25
+ dataset_name = st.sidebar.text_input("Dataset Name", value="")
26
+ st.session_state.dataset_name = dataset_name
27
+
28
+ base_model_name = st.sidebar.selectbox(
29
+ "Base Model", options=["distilbert/distilbert-base-uncased", "roberta-base"]
30
+ )
31
+ st.session_state.base_model_name = base_model_name
32
+
33
+ batch_size = st.sidebar.slider(
34
+ "Batch Size", min_value=4, max_value=256, value=16, step=4
35
+ )
36
+ st.session_state.batch_size = batch_size
37
+
38
+ train_button = st.sidebar.button("Train")
39
+ st.session_state.should_start_training = (
40
+ train_button and st.session_state.dataset_name and st.session_state.base_model_name
41
+ )
42
+
43
+ if st.session_state.should_start_training:
44
+ with st.expander("Training", expanded=True):
45
+ st.markdown(
46
+ f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
47
+ )
48
+ training_output = train_binary_classifier(
49
+ project_name="guardrails-genie",
50
+ entity_name="geekyrakshit",
51
+ dataset_repo=st.session_state.dataset_name,
52
+ model_name=st.session_state.base_model_name,
53
+ batch_size=st.session_state.batch_size,
54
+ streamlit_mode=True,
55
+ )
56
+ st.session_state.training_output = training_output
57
+ st.write(training_output)
guardrails_genie/train_classifier.py CHANGED
@@ -1,14 +1,39 @@
 
1
  import evaluate
2
  import numpy as np
3
- import wandb
4
  from datasets import load_dataset
5
  from transformers import (
6
  AutoModelForSequenceClassification,
7
  AutoTokenizer,
8
  DataCollatorWithPadding,
9
  Trainer,
 
10
  TrainingArguments,
11
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  def train_binary_classifier(
@@ -20,6 +45,7 @@ def train_binary_classifier(
20
  batch_size: int = 16,
21
  num_epochs: int = 2,
22
  weight_decay: float = 0.01,
 
23
  ):
24
  wandb.init(project=project_name, entity=entity_name)
25
  dataset = load_dataset(dataset_repo)
@@ -69,5 +95,8 @@ def train_binary_classifier(
69
  processing_class=tokenizer,
70
  data_collator=data_collator,
71
  compute_metrics=compute_metrics,
 
72
  )
73
- trainer.train()
 
 
 
1
+
2
  import evaluate
3
  import numpy as np
4
+ import streamlit as st
5
  from datasets import load_dataset
6
  from transformers import (
7
  AutoModelForSequenceClassification,
8
  AutoTokenizer,
9
  DataCollatorWithPadding,
10
  Trainer,
11
+ TrainerCallback,
12
  TrainingArguments,
13
  )
14
+ from transformers.trainer_callback import TrainerControl, TrainerState
15
+
16
+ import wandb
17
+
18
+
19
+ class StreamlitProgressbarCallback(TrainerCallback):
20
+
21
+ def __init__(self, *args, **kwargs):
22
+ super().__init__(*args, **kwargs)
23
+ self.progress_bar = st.progress(0, text="Training")
24
+
25
+ def on_step_begin(
26
+ self,
27
+ args: TrainingArguments,
28
+ state: TrainerState,
29
+ control: TrainerControl,
30
+ **kwargs,
31
+ ):
32
+ super().on_step_begin(args, state, control, **kwargs)
33
+ self.progress_bar.progress(
34
+ (state.global_step * 100 // state.max_steps) + 1,
35
+ text=f"Training {state.global_step} / {state.max_steps}",
36
+ )
37
 
38
 
39
  def train_binary_classifier(
 
45
  batch_size: int = 16,
46
  num_epochs: int = 2,
47
  weight_decay: float = 0.01,
48
+ streamlit_mode: bool = False,
49
  ):
50
  wandb.init(project=project_name, entity=entity_name)
51
  dataset = load_dataset(dataset_repo)
 
95
  processing_class=tokenizer,
96
  data_collator=data_collator,
97
  compute_metrics=compute_metrics,
98
+ callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
99
  )
100
+ training_output = trainer.train()
101
+ wandb.finish()
102
+ return training_output