Create app.py
app.py
ADDED
@@ -0,0 +1,117 @@
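# Gradio demo comparing the decision surfaces of four tree-based classifiers
# (a single decision tree, a random forest, extremely randomized trees, and
# AdaBoost) on pairs of Iris features, rendered as a 2x2 grid of plots.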
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from itertools import combinations
from functools import partial

plt.rcParams['figure.dpi'] = 100

from sklearn.datasets import load_iris
from sklearn.ensemble import (
    RandomForestClassifier,
    ExtraTreesClassifier,
    AdaBoostClassifier,
)
from sklearn.tree import DecisionTreeClassifier

import gradio as gr

# ========================================

C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
CMAP = ListedColormap([C1, C2, C3])
GRANULARITY = 0.01
SEED = 1
N_ESTIMATORS = 30

FEATURES = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
LABELS = ["Setosa", "Versicolour", "Virginica"]
MODEL_NAMES = ['DecisionTreeClassifier', 'RandomForestClassifier', 'ExtraTreesClassifier', 'AdaBoostClassifier']

iris = load_iris()

MODELS = [
    DecisionTreeClassifier(max_depth=None),
    RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=N_ESTIMATORS)
]

# ========================================

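# MODELS holds one unfitted template estimator per panel; create_plot() refits
# the selected one on every UI event. N_ESTIMATORS is only the initial default:
# the slider value is applied via set_params() inside create_plot(), which is
# assumed to be the intent of its otherwise unused n_estimators argument.
# (GRANULARITY is currently unused; the mesh step is hard-coded to 0.1 below.)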
def create_plot(feature_string, n_estimators, model_idx):
    np.random.seed(SEED)

    # Resolve the two selected feature names to column indices.
    feature_list = [s.strip() for s in feature_string.split(',')]
    idx_x = FEATURES.index(feature_list[0])
    idx_y = FEATURES.index(feature_list[1])

    X = iris.data[:, [idx_x, idx_y]]
    y = iris.target

    # Shuffle, then standardize the two selected features.
    rnd_idx = np.random.permutation(X.shape[0])
    X = X[rnd_idx]
    y = y[rnd_idx]

    X = (X - X.mean(0)) / X.std(0)

    model_name = MODEL_NAMES[model_idx]
    model = MODELS[model_idx]
    # Apply the slider value to the ensembles; the plain decision tree has no
    # n_estimators parameter, so it is left unchanged.
    if 'n_estimators' in model.get_params():
        model.set_params(n_estimators=n_estimators)

    model.fit(X, y)
    score = round(model.score(X, y), 3)  # training accuracy

    # Predict over a mesh covering the feature plane to draw decision regions.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xrange = np.arange(x_min, x_max, 0.1)
    yrange = np.arange(y_min, y_max, 0.1)
    xx, yy = np.meshgrid(xrange, yrange)

    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)

    # Overlay the training points, one class at a time so each gets a legend entry.
    for i, label in enumerate(LABELS):
        X_label = X[y == i, :]
        ax.scatter(X_label[:, 0], X_label[:, 1], c=[C1, C2, C3][i], edgecolor='k', s=40, label=label)

    ax.set_xlabel(feature_list[0])
    ax.set_ylabel(feature_list[1])
    ax.legend()
    ax.set_title(f'{model_name} | Score: {score}')

    return fig

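# iter_grid() is a small layout helper: it yields once per cell of an
# n_rows x n_cols grid of gr.Row()/gr.Column() containers, so the loop below
# can place one plot per cell.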
def iter_grid(n_rows, n_cols):
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield

with gr.Blocks() as demo:
    selections = combinations(FEATURES, 2)
    selections = [f'{s[0]}, {s[1]}' for s in selections]
    dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
    slider = gr.Slider(1, 100, value=30, step=1, label='n_estimators')

    counter = 0
    for _ in iter_grid(2, 2):
        if counter >= len(MODELS):
            break

        plot = gr.Plot(label=f'{MODEL_NAMES[counter]}')
        fn = partial(create_plot, model_idx=counter)

        dd.change(fn, inputs=[dd, slider], outputs=[plot])
        slider.change(fn, inputs=[dd, slider], outputs=[plot])
        demo.load(fn, inputs=[dd, slider], outputs=[plot])

        counter += 1

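# share=True creates a temporary public link when the script is run locally;
# on Hugging Face Spaces the app is already served publicly, so the flag
# mainly matters for local runs.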
demo.launch(share=True, debug=True)