import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from sklearn import linear_model

# np.random.seed only accepts seeds in the range [0, 2**32 - 1].
_MAX_SEED = 2**32 - 1


def _parse_seed(seed):
    """Convert the Seed textbox contents to an int, or None when it is empty.

    Raises:
        gr.Error: If the text is not a plain non-negative decimal integer that
            np.random.seed accepts.
    """
    if seed is None:
        return None
    text = str(seed).strip()
    if not text:
        return None
    # str.isnumeric() also accepts characters such as "²" or "½" that int()
    # cannot parse, so only plain ASCII digits are accepted here.
    if not (text.isascii() and text.isdigit()):
        raise gr.Error("Invalid seed")
    value = int(text)
    if value > _MAX_SEED:
        raise gr.Error(f"Seed must be between 0 and {_MAX_SEED}")
    return value


def _fit_boundary(ax, xx, yy, X, y, linestyle, sample_weight=None):
    """Fit an SGDClassifier on (X, y) and draw its level-0 decision contour.

    Returns the ContourSet so the caller can build a legend entry from it.
    """
    clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)
    clf.fit(X, y, sample_weight=sample_weight)
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return ax.contour(xx, yy, Z, levels=[0], linestyles=[linestyle])


def plot(seed, num_points):
    """Plot SGD decision boundaries fitted with and without sample weights.

    Args:
        seed: Seed textbox contents. An integer seed makes the data
            reproducible; empty keeps drawing from numpy's global RNG, so
            every run produces new points.
        num_points: Total number of samples. Odd values are rounded up to the
            next even number so both classes get the same count.

    Returns:
        A matplotlib Figure for ``gr.Plot``.

    Raises:
        gr.Error: If the seed is invalid or there are too few points.
    """
    seed = _parse_seed(seed)
    # Compare against None so that "0" is still a valid seed.
    if seed is not None:
        np.random.seed(seed)

    num_points = int(num_points)
    # Ensure the number of points is even so the two classes are balanced.
    if num_points % 2 != 0:
        num_points += 1
    half_num_points = num_points // 2
    if half_num_points < 1:
        raise gr.Error("Number of points must be at least 2")

    # First half: class +1 shifted by (1, 1); second half: class -1 at origin.
    X = np.r_[
        np.random.randn(half_num_points, 2) + [1, 1],
        np.random.randn(half_num_points, 2),
    ]
    y = [1] * half_num_points + [-1] * half_num_points
    sample_weight = 100 * np.abs(np.random.randn(num_points))
    # Give the first half of the samples (the +1 class) ten times more weight.
    sample_weight[:half_num_points] *= 10

    xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))

    # Create the Figure directly instead of through pyplot: pyplot keeps every
    # figure alive until plt.close(), which leaks memory in a server process.
    fig = Figure()
    ax = fig.subplots()
    # Plot the data points; marker area is proportional to each sample's weight.
    ax.scatter(
        X[:, 0],
        X[:, 1],
        c=y,
        s=sample_weight,
        alpha=0.9,
        cmap=plt.cm.bone,
        edgecolor="black",
    )

    # Unweighted fit first, then weighted, as before. With random_state unset
    # SGDClassifier shuffles using numpy's global RNG, so this order keeps
    # seeded runs reproducible.
    no_weights = _fit_boundary(ax, xx, yy, X, y, "solid")
    samples_weights = _fit_boundary(
        ax, xx, yy, X, y, "dashed", sample_weight=sample_weight
    )

    no_weights_handles, _ = no_weights.legend_elements()
    weights_handles, _ = samples_weights.legend_elements()
    ax.legend(
        [no_weights_handles[0], weights_handles[0]],
        ["no weights", "with weights"],
        loc="lower left",
    )
    ax.set(xticks=(), yticks=())
    return fig


info = '''
# SGD: Weighted samples\n
This is a demonstration of a modified version of [SGD](https://scikit-learn.org/stable/modules/sgd.html#id5) that takes into account the weights of the samples.
Where the size of points is proportional to its weight.\n
Created by [@Nahrawy](https://huggingface.co/Nahrawy) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_weighted_samples.html).
'''

with gr.Blocks() as demo:
    gr.Markdown(info)
    with gr.Row():
        with gr.Column():
            seed = gr.Textbox(
                label="Seed",
                info="Leave empty to generate new random points each run ",
                value=None,
            )
            # Even-only grid (6..100): plot() needs an even count, and the
            # old minimum of 5 with step 2 only offered odd values.
            num_points = gr.Slider(
                label="Number of Points", value=20, minimum=6, maximum=100, step=2
            )
            btn = gr.Button("Run")
        out = gr.Plot()
    btn.click(fn=plot, inputs=[seed, num_points], outputs=out)

demo.launch()