LenixC commited on
Commit
0359ac6
·
1 Parent(s): 4ba5ee9

Added core functionality to gradio implementation.

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original Author: Gael Varoquaux
2
+ # Gradio Implementation: Lenix Carter
3
+ # License: BSD 3-Clause or CC-0
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.patheffects as PathEffects
10
+
11
+ from sklearn.cluster import AgglomerativeClustering
12
+ from sklearn.metrics import pairwise_distances
13
+
14
+ np.random.seed(0)
15
+ matplotlib.use('agg')
16
+ labels = ("Waveform 1", "Waveform 2", "Waveform 3")
17
+ colors = ["#f7bd01", "#377eb8", "#f781bf"]
18
+ n_clusters = 3
19
+
20
+ def sqr(x):
21
+ return np.sign(np.cos(x))
22
+
23
+ def ground_truth_plot(n_features):
24
+ t = np.pi * np.linspace(0, 1, n_features)
25
+
26
+ X = list()
27
+ y = list()
28
+ for i, (phi, a) in enumerate([(0.5, 0.15), (0.5, 0.6), (0.3, 0.2)]):
29
+ for _ in range(30):
30
+ phase_noise = 0.01 * np.random.normal()
31
+ amplitude_noise = 0.04 * np.random.normal()
32
+ additional_noise = 1 - 2 * np.random.rand(n_features)
33
+ # Make the noise sparse
34
+ additional_noise[np.abs(additional_noise) < 0.997] = 0
35
+
36
+ X.append(
37
+ 12
38
+ * (
39
+ (a + amplitude_noise) * (sqr(6 * (t + phi + phase_noise)))
40
+ + additional_noise
41
+ )
42
+ )
43
+ y.append(i)
44
+
45
+ X = np.array(X)
46
+ y = np.array(y)
47
+
48
+ # Plot the ground-truth labelling
49
+ gt_plot = plt.figure()
50
+ plt.axes([0, 0, 1, 1])
51
+ for l, color, n in zip(range(n_clusters), colors, labels):
52
+ lines = plt.plot(X[y == l].T, c=color, alpha=0.5)
53
+ lines[0].set_label(n)
54
+
55
+ plt.legend(loc="best")
56
+
57
+ plt.axis("tight")
58
+ plt.axis("off")
59
+ plt.suptitle("Ground truth", size=20, y=1)
60
+ return gt_plot, X, y
61
+
62
+ def plot_cluster_waves(metric, X, y):
63
+ model = AgglomerativeClustering(
64
+ n_clusters=n_clusters, linkage="average", metric=metric
65
+ )
66
+ model.fit(X)
67
+ clust_plot = plt.figure()
68
+ plt.axes([0, 0, 1, 1])
69
+ for l, color in zip(np.arange(model.n_clusters), colors):
70
+ plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5)
71
+ plt.axis("tight")
72
+ plt.axis("off")
73
+ plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20, y=1)
74
+ return clust_plot
75
+
76
+ def plot_distances(metric, X, y):
77
+ avg_dist = np.zeros((n_clusters, n_clusters))
78
+ dist_plot = plt.figure()
79
+ for i in range(n_clusters):
80
+ for j in range(n_clusters):
81
+ avg_dist[i, j] = pairwise_distances(
82
+ X[y == i], X[y == j], metric=metric
83
+ ).mean()
84
+ avg_dist /= avg_dist.max()
85
+ for i in range(n_clusters):
86
+ for j in range(n_clusters):
87
+ t = plt.text(
88
+ i,
89
+ j,
90
+ "%5.3f" % avg_dist[i, j],
91
+ verticalalignment="center",
92
+ horizontalalignment="center",
93
+ )
94
+ t.set_path_effects(
95
+ [PathEffects.withStroke(linewidth=5, foreground="w", alpha=0.5)]
96
+ )
97
+
98
+ plt.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0)
99
+ plt.xticks(range(n_clusters), labels, rotation=45)
100
+ plt.yticks(range(n_clusters), labels)
101
+ plt.colorbar()
102
+ plt.suptitle("Interclass %s distances" % metric, size=18, y=1)
103
+ plt.tight_layout()
104
+ return dist_plot
105
+
106
+ def agg_cluster(n_feats, measure):
107
+ plt.clf()
108
+ gt_plt, X, y = ground_truth_plot(n_feats)
109
+ cluster_waves_plot = plot_cluster_waves(measure, X, y)
110
+ dist_plot = plot_distances(measure, X, y)
111
+ return gt_plt, cluster_waves_plot, dist_plot
112
+
113
+ with gr.Blocks() as demo:
114
+ with gr.Row():
115
+ with gr.Column():
116
+ n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
117
+ measure = gr.Dropdown(["cosine", "euclidean", "cityblock"], value="cosine")
118
+ btn = gr.Button(label="Run")
119
+ gt_graph = gr.Plot(label="Ground Truth Graph")
120
+ with gr.Row():
121
+ dist_plot = gr.Plot(label="Interclass Distances")
122
+ clust_waves = gr.Plot(label="Agglomerative Clustering")
123
+
124
+ btn.click(
125
+ fn=agg_cluster,
126
+ inputs=[n_feats, measure],
127
+ outputs=[gt_graph, clust_waves, dist_plot]
128
+ )
129
+
130
+ if __name__ == '__main__':
131
+ demo.launch()
132
+