frenkt's picture
Upload 2 files
e881c02
raw
history blame
3.02 kB
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import BayesianRidge
# Seed passed to every np.random.RandomState in this file so runs are reproducible.
SEED = 1234
# Polynomial degree of the features; np.vander is called with ORDER + 1 columns.
ORDER = 3
# Size of the pre-generated training pool that app_fn picks its observations from.
MAX_SAMPLES = 100
def sin_wave(x: np.ndarray) -> np.ndarray:
    """Evaluate the noise-free target curve sin(2 * pi * x).

    Args:
        x: Points at which to evaluate the curve (a NumPy array or a scalar).

    Returns:
        ``np.sin(2 * np.pi * x)``, computed element-wise.
    """
    return np.sin(2 * np.pi * x)
def generate_train_data(n_samples: int):
    """Draw a reproducible noisy sample of the sine curve.

    Inputs are drawn uniformly from [0, 1) and the targets are ``sin_wave`` of
    those inputs plus Gaussian noise with standard deviation 0.1. The
    generator is re-seeded with ``SEED`` on every call, so the same
    ``n_samples`` always yields the same data.

    Returns:
        A tuple ``(x_train, X_train, y_train)``: the raw inputs, their
        Vandermonde matrix with ``ORDER + 1`` columns in increasing powers,
        and the noisy targets.
    """
    random_state = np.random.RandomState(SEED)
    # Draw order matters for reproducibility: inputs first, then the noise.
    inputs = random_state.uniform(low=0.0, high=1.0, size=n_samples)
    noise = random_state.normal(scale=0.1, size=n_samples)
    targets = sin_wave(inputs) + noise
    design = np.vander(inputs, N=ORDER + 1, increasing=True)
    return inputs, design, targets
def get_app_fn():
    """Build the demo callback with its test grid, model and data prepared once.

    Returns:
        A function ``app_fn(n_samples, alpha_init, lambda_init)`` that fits the
        Bayesian Ridge model on a subset of the pre-generated training data
        and returns a matplotlib figure of the fit.
    """
    x_test = np.linspace(0.0, 1.0, 100)
    X_test = np.vander(x_test, ORDER + 1, increasing=True)
    y_test = sin_wave(x_test)
    # NOTE: this estimator is shared by all calls and is refit in place.
    reg = BayesianRidge(tol=1e-6, fit_intercept=False, compute_score=True)
    x_train_full, X_train_full, y_train_full = generate_train_data(MAX_SAMPLES)
    # One fixed shuffle of the pool indices, computed once. Taking a prefix
    # yields distinct observations (sampling with replacement would repeat
    # points), and a larger n_samples only adds points to the smaller subset.
    shuffled_idx = np.random.RandomState(SEED).permutation(MAX_SAMPLES)

    def app_fn(n_samples: int, alpha_init: float, lambda_init: float):
        """Train a Bayesian Ridge regression model and plot the predicted points.

        Args:
            n_samples: Number of training observations to use; values above
                ``MAX_SAMPLES`` use the whole pool.
            alpha_init: Initial value passed to the model as ``alpha_init``.
            lambda_init: Initial value passed to the model as ``lambda_init``.

        Returns:
            The matplotlib figure showing the true curve, the observations,
            the predictive mean and a +/- one standard deviation band.
        """
        # The slider may deliver a float; slicing needs an int.
        subset_idx = shuffled_idx[: int(n_samples)]
        x_train = x_train_full[subset_idx]
        X_train = X_train_full[subset_idx]
        y_train = y_train_full[subset_idx]

        reg.set_params(alpha_init=alpha_init, lambda_init=lambda_init)
        reg.fit(X_train, y_train)
        ymean, ystd = reg.predict(X_test, return_std=True)

        fig, ax = plt.subplots()
        ax.plot(x_test, y_test, color="blue", label="sin($2\\pi x$)")
        ax.scatter(x_train, y_train, s=50, alpha=0.5, label="observation")
        ax.plot(x_test, ymean, color="red", label="predict mean")
        ax.fill_between(
            x_test,
            ymean - ystd,
            ymean + ystd,
            color="pink",
            alpha=0.5,
            label="predict std",
        )
        ax.set_ylim(-1.3, 1.3)
        ax.legend()
        text = "$\\alpha={:.1f}$\n$\\lambda={:.3f}$\n$L={:.1f}$".format(
            reg.alpha_, reg.lambda_, reg.scores_[-1]
        )
        ax.text(0.05, -1.0, text, fontsize=12)
        # pyplot keeps every figure it creates registered until it is closed;
        # without this, each slider change would leave another figure behind.
        # The Figure object itself remains usable after being closed.
        plt.close(fig)
        return fig

    return app_fn
title = "Bayesian Ridge Regression"

# Build the UI: three sliders feeding one plot, all driven by the same callback.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    # The slider maximum is tied to the size of the pre-generated data pool.
    n_samples_input = gr.Slider(minimum=5, maximum=MAX_SAMPLES, value=25, step=1, label="#observations")
    alpha_input = gr.Slider(minimum=0.001, maximum=5, value=1.9, step=0.01, label="alpha_init")
    lambda_input = gr.Slider(minimum=0.001, maximum=5, value=1., step=0.01, label="lambda_init")
    outputs = gr.Plot(label="Output")
    inputs = [n_samples_input, alpha_input, lambda_input]
    app_fn = get_app_fn()
    # Refit and redraw whenever any of the sliders changes.
    for slider in inputs:
        slider.change(fn=app_fn, inputs=inputs, outputs=outputs)
    # Render the plot once on page load so it is not empty until a slider moves.
    demo.load(fn=app_fn, inputs=inputs, outputs=outputs)

demo.launch()