ZennyKenny committed
Commit 80a6e34 · verified · 1 Parent(s): e252d2c

add dataset selector

Files changed (1):
  1. app.py +163 -103
app.py CHANGED
@@ -2,140 +2,200 @@ import gradio as gr
  import numpy as np
  import matplotlib
  import matplotlib.pyplot as plt
- from sklearn.datasets import load_iris
+ import pandas as pd
+
+ from datasets import load_dataset
  from sklearn.ensemble import GradientBoostingClassifier
  from sklearn.model_selection import train_test_split
  from sklearn.metrics import accuracy_score, confusion_matrix

- # This line ensures Matplotlib doesn't try to open windows in certain environments:
- matplotlib.use('Agg')
-
- # Load the Iris dataset
- iris = load_iris()
- X, y = iris.data, iris.target
- feature_names = iris.feature_names
- class_names = iris.target_names
-
- # Train/test split
- X_train, X_test, y_train, y_test = train_test_split(
-     X, y, test_size=0.3, random_state=42
- )
-
- def train_and_evaluate(learning_rate, n_estimators, max_depth):
-     # Train the model
+ matplotlib.use('Agg')  # Avoid issues in some remote environments
+
+ # Pre-populate a short list of "recommended" Hugging Face datasets
+ # (Replace "datasorg/iris" etc. with real dataset IDs you want to showcase)
+ SUGGESTED_DATASETS = [
+     "datasorg/iris",           # hypothetical ID
+     "uciml/wine_quality-red",  # example from the HF Hub
+     "SKIP/ENTER_CUSTOM"        # We'll treat this as a "separator" or "prompt" for custom
+ ]
+
+ def load_and_prepare_dataset(dataset_id, label_column, feature_columns):
+     """
+     Loads a dataset from the Hugging Face Hub,
+     converts it to a pandas DataFrame,
+     returns X, y as NumPy arrays for modeling.
+     """
+     # Load only the "train" split for simplicity
+     # Many datasets have "train", "test", "validation" splits
+     ds = load_dataset(dataset_id, split="train")
+
+     # Convert to a DataFrame for easy manipulation
+     df = pd.DataFrame(ds)
+
+     # Subset to selected columns
+     if label_column not in df.columns:
+         raise ValueError(f"Label column '{label_column}' not in dataset columns: {df.columns.to_list()}")
+
+     for col in feature_columns:
+         if col not in df.columns:
+             raise ValueError(f"Feature column '{col}' not in dataset columns: {df.columns.to_list()}")
+
+     # Split into X and y
+     X = df[feature_columns].values
+     y = df[label_column].values
+
+     return X, y, df.columns.tolist()
+
+ def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
+                 learning_rate, n_estimators, max_depth, test_size):
+     """
+     1. Determine final dataset ID (either from dropdown or custom text).
+     2. Load dataset -> DataFrame -> X, y.
+     3. Train a GradientBoostingClassifier.
+     4. Generate plots & metrics (accuracy and confusion matrix).
+     """
+
+     # Decide which dataset ID to use
+     if dataset_id != "SKIP/ENTER_CUSTOM":
+         final_id = dataset_id
+     else:
+         # Use the user-supplied "custom_dataset_id"
+         final_id = custom_dataset_id.strip()
+
+     # Prepare data
+     X, y, columns_available = load_and_prepare_dataset(
+         final_id,
+         label_column,
+         feature_columns
+     )
+
+     # Train/test split
+     X_train, X_test, y_train, y_test = train_test_split(
+         X, y, test_size=test_size, random_state=42
+     )
+
+     # Train model
      clf = GradientBoostingClassifier(
          learning_rate=learning_rate,
-         n_estimators=n_estimators,
+         n_estimators=int(n_estimators),
          max_depth=int(max_depth),
          random_state=42
      )
      clf.fit(X_train, y_train)
-
-     # Predict on test set
+
+     # Evaluate
      y_pred = clf.predict(X_test)
-
-     # Calculate accuracy
      accuracy = accuracy_score(y_test, y_pred)
-
-     # Calculate confusion matrix
      cm = confusion_matrix(y_test, y_pred)

-     # Create a single figure with 2 subplots
-     fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
+     # Plot figure
+     fig, axs = plt.subplots(1, 2, figsize=(10, 4))

-     # --- Subplot 1: Feature Importances ---
+     # Subplot 1: Feature Importances
      importances = clf.feature_importances_
-     axs[0].barh(range(len(feature_names)), importances, color='skyblue')
-     axs[0].set_yticks(range(len(feature_names)))
-     axs[0].set_yticklabels(feature_names)
+     axs[0].barh(range(len(feature_columns)), importances, color='skyblue')
+     axs[0].set_yticks(range(len(feature_columns)))
+     axs[0].set_yticklabels(feature_columns)
      axs[0].set_xlabel("Importance")
      axs[0].set_title("Feature Importances")

-     # --- Subplot 2: Confusion Matrix Heatmap ---
+     # Subplot 2: Confusion Matrix Heatmap
      im = axs[1].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
      axs[1].set_title("Confusion Matrix")
-     # Add colorbar
-     cbar = fig.colorbar(im, ax=axs[1])
-     # Tick marks for x/y axes
-     axs[1].set_xticks(range(len(class_names)))
-     axs[1].set_yticks(range(len(class_names)))
-     axs[1].set_xticklabels(class_names, rotation=45, ha="right")
-     axs[1].set_yticklabels(class_names)
-     axs[1].set_ylabel('True Label')
-     axs[1].set_xlabel('Predicted Label')
-
-     # Write the counts in each cell
+     plt.colorbar(im, ax=axs[1])
+     # Labeling
+     axs[1].set_xlabel("Predicted")
+     axs[1].set_ylabel("True")
+
+     # If you want to annotate each cell:
      thresh = cm.max() / 2.0
      for i in range(cm.shape[0]):
          for j in range(cm.shape[1]):
              color = "white" if cm[i, j] > thresh else "black"
-             axs[1].text(j, i, format(cm[i, j], "d"),
-                         ha="center", va="center", color=color)
+             axs[1].text(j, i, format(cm[i, j], "d"), ha="center", va="center", color=color)

      plt.tight_layout()

-     # Return textual results + the figure
-     results_text = f"Accuracy: {accuracy:.3f}"
-     return results_text, fig
-
- def predict_species(sepal_length, sepal_width, petal_length, petal_width,
-                     learning_rate, n_estimators, max_depth):
-     clf = GradientBoostingClassifier(
-         learning_rate=learning_rate,
-         n_estimators=n_estimators,
-         max_depth=int(max_depth),
-         random_state=42
-     )
-     clf.fit(X_train, y_train)
-     user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
-     prediction = clf.predict(user_sample)[0]
-     return f"Predicted species: {class_names[prediction]}"
+     output_text = f"**Dataset used:** {final_id}\n\n"
+     output_text += f"**Accuracy:** {accuracy:.3f}\n\n"
+     output_text += "**Confusion Matrix** (raw counts above)."
+
+     return output_text, fig, columns_available
+
+ def update_columns(dataset_id, custom_dataset_id):
+     """
+     Callback to dynamically fetch the columns from the dataset
+     so the user can pick which columns to use as features/labels.
+     """
+     if dataset_id != "SKIP/ENTER_CUSTOM":
+         final_id = dataset_id
+     else:
+         final_id = custom_dataset_id.strip()
+
+     # Try to load the dataset and return columns
+     try:
+         ds = load_dataset(final_id, split="train")
+         df = pd.DataFrame(ds)
+         cols = df.columns.tolist()
+         # Return as list of selectable options
+         return gr.update(choices=cols), gr.update(choices=cols), f"Columns found: {cols}"
+     except Exception as e:
+         return gr.update(choices=[]), gr.update(choices=[]), f"Error loading {final_id}: {e}"

  with gr.Blocks() as demo:
-     with gr.Tab("Train & Evaluate"):
-         gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
-
-         learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
-         n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
-         max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
-
-         train_button = gr.Button("Train & Evaluate")
-         output_text = gr.Textbox(label="Results")
-         output_plot = gr.Plot(label="Feature Importances & Confusion Matrix")
-
-         train_button.click(
-             fn=train_and_evaluate,
-             inputs=[learning_rate_slider, n_estimators_slider, max_depth_slider],
-             outputs=[output_text, output_plot],
+     gr.Markdown("## Train GradientBoostingClassifier on a Hugging Face dataset of your choice")
+
+     with gr.Row():
+         dataset_dropdown = gr.Dropdown(
+             choices=SUGGESTED_DATASETS,
+             value=SUGGESTED_DATASETS[0],
+             label="Choose a dataset"
          )
+         custom_dataset_id = gr.Textbox(label="Or enter HF dataset (user/dataset)", value="",
+                                        placeholder="e.g. 'username/my_custom_dataset'")
+
+     # Button to load columns from the chosen dataset
+     load_cols_btn = gr.Button("Load columns")
+     load_cols_info = gr.Markdown()
+
+     with gr.Row():
+         label_col = gr.Dropdown(choices=[], label="Label column (choose 1)")
+         feature_cols = gr.CheckboxGroup(choices=[], label="Feature columns (choose 1 or more)")
+
+     # Once columns are chosen, we can set hyperparams
+     learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
+     n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
+     max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
+     test_size_slider = gr.Slider(0.1, 0.9, value=0.3, step=0.1, label="test_size (fraction)")
+
+     train_button = gr.Button("Train & Evaluate")
+
+     output_text = gr.Markdown()
+     output_plot = gr.Plot()
+     # We might also want to show the columns for reference post-training
+     columns_return = gr.Markdown()
+
+     # When "Load columns" is clicked, we call update_columns to fetch the dataset columns
+     load_cols_btn.click(
+         fn=update_columns,
+         inputs=[dataset_dropdown, custom_dataset_id],
+         outputs=[label_col, feature_cols, load_cols_info]
+     )

-     with gr.Tab("Predict"):
-         gr.Markdown("## Predict Iris Species with GradientBoostingClassifier")
-
-         sepal_length_input = gr.Number(value=5.1, label=feature_names[0])
-         sepal_width_input = gr.Number(value=3.5, label=feature_names[1])
-         petal_length_input = gr.Number(value=1.4, label=feature_names[2])
-         petal_width_input = gr.Number(value=0.2, label=feature_names[3])
-
-         learning_rate_slider2 = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
-         n_estimators_slider2 = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
-         max_depth_slider2 = gr.Slider(1, 10, value=3, step=1, label="max_depth")
-
-         predict_button = gr.Button("Predict")
-         prediction_text = gr.Textbox(label="Prediction")
-
-         predict_button.click(
-             fn=predict_species,
-             inputs=[
-                 sepal_length_input,
-                 sepal_width_input,
-                 petal_length_input,
-                 petal_width_input,
-                 learning_rate_slider2,
-                 n_estimators_slider2,
-                 max_depth_slider2,
-             ],
-             outputs=prediction_text
-         )
+     # When "Train & Evaluate" is clicked, we train the model
+     train_button.click(
+         fn=train_model,
+         inputs=[
+             dataset_dropdown,
+             custom_dataset_id,
+             label_col,
+             feature_cols,
+             learning_rate_slider,
+             n_estimators_slider,
+             max_depth_slider,
+             test_size_slider
+         ],
+         outputs=[output_text, output_plot, columns_return]
+     )

  demo.launch()
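
For reference, the new data path can be exercised without the Gradio UI. Below is a minimal headless sketch of the same flow; `DATASET_ID`, `LABEL_COLUMN`, and `FEATURE_COLUMNS` are hypothetical placeholders to be replaced with a real Hub dataset and its columns (the selected feature columns must be numeric for `GradientBoostingClassifier`):

```python
# Headless sketch of the commit's data path (hypothetical dataset ID and columns).
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

DATASET_ID = "username/my_custom_dataset"  # hypothetical HF dataset ID
LABEL_COLUMN = "label"                     # hypothetical label column
FEATURE_COLUMNS = ["feat_a", "feat_b"]     # hypothetical numeric feature columns

ds = load_dataset(DATASET_ID, split="train")  # same "train"-only shortcut as app.py
df = pd.DataFrame(ds)

X = df[FEATURE_COLUMNS].values
y = df[LABEL_COLUMN].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

clf = GradientBoostingClassifier(learning_rate=0.1, n_estimators=100,
                                 max_depth=3, random_state=42)
clf.fit(X_train, y_train)
print(f"Accuracy: {accuracy_score(y_test, clf.predict(X_test)):.3f}")
```

As in `app.py`, only the `train` split is requested, so a dataset without one will raise; `update_columns` catches such errors in its `try/except`, while `train_model` lets them propagate.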
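
The column picker relies on returning `gr.update(choices=...)` from a callback to repopulate the initially empty `Dropdown` and `CheckboxGroup`. A standalone sketch of just that pattern, with a hypothetical `fetch_cols` stub standing in for the real `load_dataset` call:

```python
import gradio as gr

def fetch_cols(dataset_id):
    # In the app these come from load_dataset(...); hard-coded here for the sketch.
    cols = ["col_a", "col_b", "label"]
    return gr.update(choices=cols), gr.update(choices=cols)

with gr.Blocks() as demo:
    dataset_id = gr.Textbox(label="Dataset ID")
    btn = gr.Button("Load columns")
    label_col = gr.Dropdown(choices=[], label="Label column")
    feature_cols = gr.CheckboxGroup(choices=[], label="Feature columns")
    btn.click(fn=fetch_cols, inputs=dataset_id, outputs=[label_col, feature_cols])

demo.launch()
```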