|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import numpy as np |
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patheffects as PathEffects |
|
|
|
from sklearn.cluster import AgglomerativeClustering |
|
from sklearn.metrics import pairwise_distances |
|
|
|
# Use the non-interactive Agg backend: figures are handed to Gradio `gr.Plot`
# components rather than shown in a GUI window.
matplotlib.use("agg")
# Seed NumPy's global random number generator once, at import time.
np.random.seed(0)
|
# Display names of the three waveform classes, in class-index order.
labels = "Waveform 1", "Waveform 2", "Waveform 3"
# One plot colour per class / cluster, in the same order as `labels`.
colors = ["#f7bd01", "#377eb8", "#f781bf"]
# Number of waveform classes, also the cluster count requested from the model.
n_clusters = len(labels)
|
|
|
def sqr(x):
    """Return a unit square wave in phase with ``cos``.

    Element-wise: +1 where ``cos(x) > 0``, -1 where ``cos(x) < 0`` and 0
    where ``cos(x)`` is exactly zero. Accepts scalars or array-likes.
    """
    cosine = np.cos(x)
    return np.sign(cosine)
|
|
|
def ground_truth_plot(n_features, random_state=0):
    """Generate the noisy waveform dataset and plot it coloured by true class.

    Three waveform families are built from ``sqr`` with their own phase offset
    and amplitude; 30 noisy samples are drawn per family.

    Parameters
    ----------
    n_features : int
        Number of time points per waveform. Cast to ``int`` because a UI
        slider may deliver a float, which ``linspace``/``rand`` reject.
    random_state : int or None, default 0
        Seed for a private ``np.random.RandomState``. A fixed seed makes every
        call with the same ``n_features`` return the same data, so changing
        only the metric in the UI compares metrics on identical samples.
        ``None`` draws fresh OS entropy on each call.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Line plot of every sample, one colour and legend entry per class.
    X : ndarray of shape (90, n_features)
        The generated waveforms, one per row.
    y : ndarray of shape (90,)
        Class index (0, 1 or 2) of each row of ``X``.
    """
    n_features = int(n_features)
    # Private generator: unlike the global RNG it does not drift between calls.
    rng = np.random.RandomState(random_state)
    t = np.pi * np.linspace(0, 1, n_features)

    X = []
    y = []
    # (phase offset, amplitude) of each waveform family, in class order.
    for class_idx, (phi, a) in enumerate([(0.5, 0.15), (0.5, 0.6), (0.3, 0.2)]):
        for _ in range(30):
            phase_noise = 0.01 * rng.normal()
            amplitude_noise = 0.04 * rng.normal()
            # Uniform in (-1, 1]; keeping only |v| >= 0.997 leaves sparse
            # spikes on top of the clean waveform.
            additional_noise = 1 - 2 * rng.rand(n_features)
            additional_noise[np.abs(additional_noise) < 0.997] = 0
            X.append(
                12
                * (
                    (a + amplitude_noise) * (sqr(6 * (t + phi + phase_noise)))
                    + additional_noise
                )
            )
            y.append(class_idx)
    X = np.array(X)
    y = np.array(y)

    fig, ax = plt.subplots()
    for class_idx, color, name in zip(range(n_clusters), colors, labels):
        lines = ax.plot(X[y == class_idx].T, c=color, alpha=0.5)
        # Label only the first line of each class so the legend has one entry per class.
        lines[0].set_label(name)

    fig.subplots_adjust(top=0.8, bottom=0, left=0, right=1.0)
    ax.set_title("Ground Truth", size=20, pad=1)
    ax.legend(loc="best")
    ax.axis("off")

    return fig, X, y
|
|
|
def plot_cluster_waves(metric, X, y):
    """Cluster the waveforms with average-linkage agglomeration and plot them.

    Parameters
    ----------
    metric : str
        Distance metric passed to ``AgglomerativeClustering``.
    X : ndarray
        Samples to cluster, one waveform per row.
    y : ndarray
        Accepted but unused; ``agg_cluster`` passes the same arguments here as
        to ``plot_distances``.

    Returns
    -------
    matplotlib.figure.Figure
        Every row of ``X`` drawn in the colour of its assigned cluster.
    """
    clusterer = AgglomerativeClustering(
        n_clusters=n_clusters, linkage="average", metric=metric
    )
    clusterer.fit(X)
    assignments = clusterer.labels_

    fig, ax = plt.subplots()
    for cluster_id, color in zip(np.arange(clusterer.n_clusters), colors):
        members = X[assignments == cluster_id]
        ax.plot(members.T, c=color, alpha=0.5)

    fig.subplots_adjust(top=0.75, bottom=0, left=0, right=1.0)
    ax.set_title("Agglomerative Clustering\n(metric=%s)" % metric, size=20, pad=1.0)
    ax.axis("tight")
    ax.axis("off")
    return fig
|
|
|
def plot_distances(metric, X, y):
    """Plot a heatmap of the mean distance between every pair of classes.

    Cell ``[i, j]`` is the mean ``metric`` distance between all samples of
    class ``i`` and all samples of class ``j``, normalised so the largest cell
    is 1. Rows and columns are labelled with the class names.

    Parameters
    ----------
    metric : str
        Any metric accepted by ``sklearn.metrics.pairwise_distances``.
    X : ndarray
        Samples, one per row.
    y : ndarray
        Class index of each row of ``X``; expected to cover ``range(n_clusters)``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    avg_dist = np.array(
        [
            [
                pairwise_distances(X[y == i], X[y == j], metric=metric).mean()
                for j in range(n_clusters)
            ]
            for i in range(n_clusters)
        ]
    )
    peak = avg_dist.max()
    if peak > 0:  # avoid 0/0 -> NaN when every class collapses to one point
        avg_dist /= peak

    fig, ax = plt.subplots()
    image = ax.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0)

    # White halo keeps the numbers readable on both dark and light cells.
    halo = [PathEffects.withStroke(linewidth=5, foreground="w", alpha=0.5)]
    for i in range(n_clusters):
        for j in range(n_clusters):
            # imshow draws avg_dist[row, col] at (x=col, y=row).
            ax.text(
                j,
                i,
                "%5.3f" % avg_dist[i, j],
                verticalalignment="center",
                horizontalalignment="center",
                path_effects=halo,
            )

    # Axes stay visible (no axis("off")) so these class labels are shown.
    ax.set_xticks(range(n_clusters))
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticks(range(n_clusters))
    ax.set_yticklabels(labels)
    fig.colorbar(image, ax=ax)
    ax.set_title("Interclass %s distances" % metric, size=20, pad=1.0)
    # Make room for the rotated tick labels, the title and the colorbar.
    fig.tight_layout()
    return fig
|
|
|
def agg_cluster(n_feats, measure):
    """Build the three demo figures for one UI state.

    Parameters
    ----------
    n_feats : number
        Slider value: number of time points per generated waveform.
    measure : str
        Metric name used for both the clustering and the distance heatmap.

    Returns
    -------
    tuple of three matplotlib Figures
        (ground truth, clustered waveforms, interclass distances), matching
        the order of the UI's ``outputs`` lists.
    """
    gt_fig, X, y = ground_truth_plot(n_feats)
    cluster_fig = plot_cluster_waves(measure, X, y)
    dist_fig = plot_distances(measure, X, y)
    # pyplot keeps a reference to every figure it creates, so a long-running
    # server would accumulate them without bound. Closing only deregisters
    # them from pyplot; the Figure objects can still be saved/rendered by the
    # caller. (The previous plt.clf() freed nothing and created a stray empty
    # figure on the first call.)
    for fig in (gt_fig, cluster_fig, dist_fig):
        plt.close(fig)
    return gt_fig, cluster_fig, dist_fig
|
|
|
title = "Agglomerative clustering with different metrics"

with gr.Blocks() as demo:
    gr.Markdown(f" # {title}")
    gr.Markdown(
        """
    This example demonstrates the effect of different metrics on hierarchical clustering.

    This is based on the example [here](https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering_metrics.html#sphx-glr-auto-examples-cluster-plot-agglomerative-clustering-metrics-py)
    """
    )

    with gr.Row():
        with gr.Column():
            n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
            measure = gr.Radio(
                ["cosine", "euclidean", "cityblock"], label="Metric", value="cosine"
            )
        gt_graph = gr.Plot(label="Ground Truth Graph")
    with gr.Row():
        dist_plot = gr.Plot(label="Interclass Distances")
        clust_waves = gr.Plot(label="Agglomerative Clustering")

    # Order must match the tuple returned by agg_cluster.
    plot_outputs = [gt_graph, clust_waves, dist_plot]
    plot_inputs = [n_feats, measure]

    # Redraw everything whenever either control changes.
    for control in (n_feats, measure):
        control.change(fn=agg_cluster, inputs=plot_inputs, outputs=plot_outputs)

    # Render once on page load so the plots are not empty until the user
    # first touches a control.
    demo.load(fn=agg_cluster, inputs=plot_inputs, outputs=plot_outputs)


if __name__ == "__main__":
    demo.launch()
|
|
|
|