# Provenance (Hugging Face Space file page): author huabdul,
# commit cf7215e "Update app.py", file size 4.55 kB.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
import gradio as gr

plt.rcParams['figure.dpi'] = 100
plt.style.use('ggplot')
#==========================================================================
# Scatter colours for the two classes: label 0 -> C1 (red), label 1 -> C2 (blue).
C1 = '#ff0000'
C2 = '#0000ff'
CMAP = ListedColormap([C1, C2])

# Spacing, in data units, of the grid on which the decision surface is evaluated.
GRANULARITY = 0.01
# Padding added around the data extent for axis limits and the surface grid.
MARGIN = 0.5
# Number of points generated for each toy dataset.
N_SAMPLES = 150
#==========================================================================
def get_decision_surface(X, model, *, margin=None, granularity=None):
    """Evaluate ``model``'s class-1 probability on a regular grid around ``X``.

    The grid covers the bounding box of columns 0 and 1 of ``X``, padded by
    ``margin`` on every side, with points spaced ``granularity`` apart.

    Parameters
    ----------
    X : 2-D array
        Points whose extent defines the grid; only columns 0 and 1 are used.
    model : object
        Fitted estimator with ``predict_proba``; column 1 of its output is
        used as the surface value.
    margin : float, optional
        Padding around the data range. Defaults to the module's ``MARGIN``.
    granularity : float, optional
        Grid step. Defaults to the module's ``GRANULARITY``.

    Returns
    -------
    xx, yy : ndarray
        Meshgrid coordinate arrays.
    Z : ndarray
        ``predict_proba(grid)[:, 1]`` reshaped to ``xx.shape``.

    Raises
    ------
    ValueError
        If ``granularity`` is not positive.
    """
    # Resolve defaults at call time so module-level tweaks are honoured.
    if margin is None:
        margin = MARGIN
    if granularity is None:
        granularity = GRANULARITY
    if granularity <= 0:
        raise ValueError(f"granularity must be positive, got {granularity!r}")

    x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
    y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, granularity),
        np.arange(y_min, y_max, granularity),
    )
    # Flatten the grid into (n_points, 2) rows for the estimator.
    grid = np.column_stack([xx.ravel(), yy.ravel()])
    Z = model.predict_proba(grid)[:, 1].reshape(xx.shape)
    return xx, yy, Z
def _make_datasets(seed):
    """Return the moons, circles and noisy linearly separable ``(X, y)`` datasets.

    Every dataset is generated from ``seed``, so equal seeds give equal data.
    """
    X, y = make_classification(
        n_samples=N_SAMPLES,
        n_features=2,
        n_redundant=0,
        n_informative=2,
        random_state=seed,
        n_clusters_per_class=1,
    )
    # Add per-coordinate uniform noise in [0, 2) to the separable set.
    rng = np.random.RandomState(seed)
    X += 2 * rng.uniform(size=X.shape)
    return [
        make_moons(n_samples=N_SAMPLES, noise=0.3, random_state=seed),
        make_circles(n_samples=N_SAMPLES, noise=0.2, factor=0.5, random_state=seed),
        (X, y),
    ]


def _make_model(alpha, seed):
    """Return an unfitted StandardScaler -> MLPClassifier pipeline with L2 penalty ``alpha``."""
    # `early_stopping` is deliberately not set: scikit-learn ignores it for the
    # 'lbfgs' solver (it only applies to 'sgd' and 'adam').
    return make_pipeline(
        StandardScaler(),
        MLPClassifier(
            solver="lbfgs",
            alpha=alpha,
            random_state=seed,
            max_iter=2000,
            hidden_layer_sizes=(10, 10),
        ),
    )


def create_plot(alpha, seed):
    """Fit an MLP on three toy datasets and draw a 3x2 grid of panels.

    Each row is one dataset. The data is split 60/40 into train/test sets and
    the model is refitted on the training split. The left panel shows the
    training points. The right panel shows the test points over the model's
    predicted class-1 probability surface.

    Parameters
    ----------
    alpha : float
        L2 regularization strength passed to ``MLPClassifier``.
    seed : int or float
        Seed for data generation, the train/test split and weight
        initialisation. Coerced to ``int`` because UI sliders may deliver
        whole numbers as floats, which NumPy/scikit-learn reject as seeds.

    Returns
    -------
    matplotlib.figure.Figure
    """
    seed = int(seed)
    model = _make_model(alpha, seed)

    # Build the figure without pyplot: repeated UI callbacks would otherwise
    # pile up never-closed figures in pyplot's global (thread-unsafe) registry.
    fig = Figure(figsize=(7, 7))
    # Pin the colour scale to the full probability range so colour intensity
    # reflects the predicted probability itself, not the data's min/max.
    levels = np.linspace(0, 1, 11)

    for i, (X, y) in enumerate(_make_datasets(seed)):
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.4, random_state=seed
        )
        model.fit(X_train, y_train)  # refit from scratch for each dataset

        # Left column: training points.
        ax = fig.add_subplot(3, 2, 2 * i + 1)
        ax.set_xticks(())
        ax.set_yticks(())
        ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=CMAP, edgecolor='k', s=40)
        ax.set_xlim((X[:, 0].min() - MARGIN, X[:, 0].max() + MARGIN))
        ax.set_ylim((X[:, 1].min() - MARGIN, X[:, 1].max() + MARGIN))
        if i == 0:
            ax.set_title('Training Data')

        # Right column: probability surface with the held-out points on top.
        ax = fig.add_subplot(3, 2, 2 * i + 2)
        ax.set_xticks(())
        ax.set_yticks(())
        xx, yy, Z = get_decision_surface(X, model)
        ax.contourf(xx, yy, Z, levels=levels, cmap=plt.cm.RdBu, alpha=0.65)
        ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=CMAP, edgecolor='k', s=40, marker="X")
        if i == 0:
            ax.set_title('Testing Data')

    # Replaces the deprecated Figure.set_tight_layout(True).
    fig.tight_layout()
    return fig
# Markdown shown above the controls. Paragraphs are separated by blank lines
# so Markdown renders them as separate paragraphs.
info = '''
# Effect of Regularization Parameter of Multilayer Perceptron

This example demonstrates the effect of varying the regularization parameter (alpha) of a multilayer perceptron on the binary classification of toy datasets, as represented by the decision surface of the classifier.

Higher values of alpha encourage smaller weights, making the model less prone to overfitting, while lower values may help against underfitting. Use the alpha slider below to control the amount of regularization and observe how the decision surface changes with higher values. The random seed slider regenerates the datasets, the train/test split and the network's initial weights.

The color of the decision surface shows the predicted class probability: bluer regions favor the blue class and redder regions favor the red class. Darker colors mean higher probability and thus higher confidence, and vice versa.

Created by [@huabdul](https://huggingface.co/huabdul) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/neural_networks/plot_mlp_alpha.html).
'''
# Layout: description and controls on the left, the plot on the right.
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(info)
            s_alpha = gr.Slider(0, 4, value=0.1, step=0.05, label='Alpha (regularization parameter)')
            s_seed = gr.Slider(1, 5000, value=1, step=1, label='Random seed')
        with gr.Column():
            plot = gr.Plot(show_label=False)

    # Redraw whenever either control changes, and once when the page loads.
    controls = [s_alpha, s_seed]
    for control in controls:
        control.change(create_plot, inputs=controls, outputs=[plot])
    demo.load(create_plot, inputs=controls, outputs=[plot])

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()
#==========================================================================