File size: 3,772 Bytes
e839c0e bad0412 e839c0e bad0412 e839c0e bad0412 e839c0e bad0412 e839c0e bad0412 e839c0e 3dd7e72 e839c0e 957d83a 229043f 99dfe9b 68c4a5f 99dfe9b 229043f 99dfe9b 229043f bad0412 e839c0e c95e646 48ad3a5 e4abb69 48ad3a5 c95e646 e839c0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
def get_clusters_plot(n_blobs, quantile, cluster_std, n_samples=10000, random_state=None):
    """Cluster synthetic 2-D blobs with Mean Shift and plot the result.

    Parameters
    ----------
    n_blobs : int or float
        Number of blob centers to generate. UI sliders may deliver this as a
        float, so it is cast to ``int`` before being passed to ``make_blobs``.
    quantile : float
        Quantile in [0, 1] forwarded to ``estimate_bandwidth``. Larger values
        consider more neighbours and therefore yield a larger bandwidth.
    cluster_std : float
        Standard deviation of each generated blob.
    n_samples : int, optional
        Total number of points to generate (default 10000).
    random_state : int, numpy.random.RandomState or None, optional
        Forwarded to ``make_blobs``; pass a value for reproducible data.
        ``None`` (the default) draws fresh data on every call.

    Returns
    -------
    tuple
        ``(fig, message)``: a matplotlib ``Figure`` with the clustered points
        and estimated centers, and an HTML snippet comparing the estimated
        cluster count to the true one.
    """
    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates (one leaked figure per UI callback here) and
    # its implicit "current figure" state is not safe in a threaded server.
    from matplotlib.figure import Figure

    n_blobs = int(n_blobs)
    X, _, centers = make_blobs(
        n_samples=n_samples,
        cluster_std=cluster_std,
        centers=n_blobs,
        return_centers=True,
        random_state=random_state,
    )

    fig = Figure()
    ax = fig.add_subplot()
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    if bandwidth <= 0:
        # With a tiny quantile (e.g. 0) each point's only neighbour is itself,
        # so the estimated bandwidth is 0 and MeanShift would reject it.
        ax.scatter(X[:, 0], X[:, 1])
        ax.set_title("Could not estimate a bandwidth")
        message = (
            '<p style="text-align: center;">'
            + "The estimated bandwidth is 0 for this `Quantile` value."
            + " Try increasing the `Quantile` parameter.</p>"
        )
        return fig, message

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_
    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)

    # Iterate over the labels actually present so each scatter is paired
    # with its own center, even if the labels are not a dense 0..n-1 range.
    for k in labels_unique:
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )
    ax.set_title(f"Estimated number of clusters: {n_clusters_}")

    if len(centers) != n_clusters_:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!</p>"
        )
    return fig, message
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Mean Shift Clustering
This space shows how to use the [Mean Shift Clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) algorithm to cluster 2D data points. You can change the parameters using the sliders and see how the model performs.
This space is based on [sklearn's original demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py).
"""
    )
    with gr.Row():
        # Left column: the controls that drive get_clusters_plot.
        with gr.Column(scale=1):
            n_blobs = gr.Slider(
                minimum=2,
                maximum=10,
                label="Number of clusters in the data",
                step=1,
                value=3,
            )
            # The minimum is one step above 0: a quantile of 0 makes
            # estimate_bandwidth return a zero bandwidth, which MeanShift
            # cannot fit with.
            quantile = gr.Slider(
                minimum=0.05,
                maximum=1,
                step=0.05,
                value=0.2,
                label="Quantile",
                info="Used to determine clustering's bandwidth.",
            )
            cluster_std = gr.Slider(
                minimum=0.1,
                maximum=1,
                label="Clusters' standard deviation",
                step=0.1,
                value=0.6,
            )
        # Right column: the outputs produced by get_clusters_plot.
        with gr.Column(scale=4):
            clusters_plots = gr.Plot(label="Clusters' Plot")
            message = gr.HTML()

    inputs = [n_blobs, quantile, cluster_std]
    outputs = [clusters_plots, message]
    # Re-run the clustering whenever any slider changes, and once when the
    # page first loads so the plot is populated immediately.
    for control in inputs:
        control.change(get_clusters_plot, inputs, outputs, queue=False)
    demo.load(get_clusters_plot, inputs, outputs, queue=False)

if __name__ == "__main__":
    demo.launch()
|