"""Gradio demo: Mean Shift clustering of synthetic Gaussian blobs."""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs


def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster synthetic blobs with Mean Shift and plot the result.

    Args:
        n_blobs: Number of blob centers to generate (the "true" number of
            clusters). Cast to ``int`` because UI/API values may arrive as
            floats, which ``make_blobs`` rejects for ``centers``.
        quantile: Forwarded to ``estimate_bandwidth``. Must be > 0: a quantile
            of 0 produces a zero bandwidth, which ``MeanShift`` rejects.
        cluster_std: Standard deviation of each generated blob.

    Returns:
        A ``(fig, message)`` tuple. ``fig`` is a matplotlib figure with one
        scatter series per estimated cluster and an "x" marker at each
        estimated cluster center. ``message`` is an HTML snippet saying
        whether the estimated cluster count matches ``n_blobs``.
    """
    n_blobs = int(n_blobs)
    X, _ = make_blobs(n_samples=10000, cluster_std=cluster_std, centers=n_blobs)

    # Estimate the bandwidth from a random 500-point subsample rather than
    # from all 10000 points.
    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_
    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)

    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # so the plot does not depend on global pyplot state.
    fig, ax = plt.subplots()
    # Iterate over the labels that actually occur, so each center is paired
    # with its own members.
    for k in labels_unique:
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )
    # pyplot keeps every figure it creates alive until it is closed. Closing
    # here prevents one leaked figure per request in a long-running server.
    # The Figure object itself stays usable, so it can still be rendered
    # after this call.
    plt.close(fig)

    if n_clusters_ != n_blobs:
        message = (
            "<p>"
            f"The number of estimated clusters ({n_clusters_})"
            f" differs from the true number of clusters ({n_blobs})."
            " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            "<p>"
            f"The number of estimated clusters ({n_clusters_})"
            f" matches the true number of clusters ({n_blobs})!</p>"
        )
    return fig, message


demo = gr.Interface(
    get_clusters_plot,
    [
        gr.Slider(
            minimum=2, maximum=10, label="Number of clusters in data", step=1, value=3
        ),
        gr.Slider(
            # Starts at 0.05 rather than 0: quantile=0 yields a zero bandwidth,
            # which MeanShift rejects, crashing the demo.
            minimum=0.05,
            maximum=1,
            step=0.05,
            value=0.2,
            label="Quantile",
            info="Used to determine clustering's bandwidth.",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1,
            label="Clusters standard deviation",
            step=0.1,
            value=0.6,
        ),
    ],
    [gr.Plot(label="Clusters' Plot"), gr.HTML()],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()