caliex commited on
Commit
4eb7f49
·
1 Parent(s): faaf3b2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from matplotlib import pyplot as plt
4
+ from matplotlib.collections import LineCollection
5
+ from sklearn import manifold
6
+ from sklearn.metrics import euclidean_distances
7
+ from sklearn.decomposition import PCA
8
+
9
+ EPSILON = np.finfo(np.float32).eps
10
+ n_samples = 20
11
+ seed = np.random.RandomState(seed=3)
12
+ X_true = seed.randint(0, 20, 2 * n_samples).astype(float)
13
+ X_true = X_true.reshape((n_samples, 2))
14
+ # Center the data
15
+ X_true -= X_true.mean()
16
+
17
+ similarities = euclidean_distances(X_true)
18
+
19
+ # Add noise to the similarities
20
+ noise = np.random.rand(n_samples, n_samples)
21
+ noise = noise + noise.T
22
+ noise[np.arange(noise.shape[0]), np.arange(noise.shape[0])] = 0
23
+ similarities += noise
24
+
25
+ def mds_nmds(n_components, max_iter, eps):
26
+ mds = manifold.MDS(
27
+ n_components=n_components,
28
+ max_iter=max_iter,
29
+ eps=eps,
30
+ random_state=seed,
31
+ dissimilarity="precomputed",
32
+ n_jobs=1,
33
+ normalized_stress="auto",
34
+ )
35
+ pos = mds.fit(similarities).embedding_
36
+
37
+ nmds = manifold.MDS(
38
+ n_components=n_components,
39
+ metric=False,
40
+ max_iter=max_iter,
41
+ eps=eps,
42
+ dissimilarity="precomputed",
43
+ random_state=seed,
44
+ n_jobs=1,
45
+ n_init=1,
46
+ normalized_stress="auto",
47
+ )
48
+ npos = nmds.fit_transform(similarities, init=pos)
49
+
50
+ # Rescale the data
51
+ pos *= np.sqrt((X_true**2).sum()) / np.sqrt((pos**2).sum())
52
+ npos *= np.sqrt((X_true**2).sum()) / np.sqrt((npos**2).sum())
53
+
54
+ # Rotate the data
55
+ clf = PCA(n_components=2)
56
+ X_true_transformed = clf.fit_transform(X_true)
57
+ pos_transformed = clf.fit_transform(pos)
58
+ npos_transformed = clf.fit_transform(npos)
59
+
60
+ return X_true_transformed, pos_transformed, npos_transformed
61
+
62
+
63
+ def plot_similarity_scatter(similarity_threshold=50, n_components=2, max_iter=3000, eps=1e-9, cmap_name='Blues'):
64
+ X_true_transformed, pos_transformed, npos_transformed = mds_nmds(n_components, max_iter, eps)
65
+
66
+ fig = plt.figure()
67
+ ax = plt.axes([0.0, 0.0, 1.0, 1.0])
68
+
69
+ s = 100
70
+ plt.scatter(X_true_transformed[:, 0], X_true_transformed[:, 1], color="navy", s=s, lw=0, label="True Position")
71
+ plt.scatter(pos_transformed[:, 0], pos_transformed[:, 1], color="turquoise", s=s, lw=0, label="MDS")
72
+ plt.scatter(npos_transformed[:, 0], npos_transformed[:, 1], color="darkorange", s=s, lw=0, label="NMDS")
73
+ plt.legend(scatterpoints=1, loc="best", shadow=False)
74
+
75
+ similarities_thresholded = similarities.copy()
76
+ similarities_thresholded[similarities_thresholded <= int(similarity_threshold)] = 0
77
+
78
+ np.fill_diagonal(similarities_thresholded, 0)
79
+ # Plot the edges
80
+ start_idx, end_idx = np.where(pos_transformed)
81
+ segments = [[X_true_transformed[i, :], X_true_transformed[j, :]] for i in range(len(pos_transformed)) for j in range(len(pos_transformed))]
82
+ values = np.abs(similarities_thresholded)
83
+ lc = LineCollection(segments, zorder=0, cmap=plt.cm.get_cmap(cmap_name), norm=plt.Normalize(0, values.max()))
84
+ lc.set_array(similarities_thresholded.flatten())
85
+ lc.set_linewidths(np.full(len(segments), 0.5))
86
+ ax.add_collection(lc)
87
+
88
+ # Save the plot as a PNG file
89
+ plt.savefig("plot.png")
90
+ plt.close()
91
+
92
+ # Return the saved plot file
93
+ return "plot.png"
94
+
95
+
96
+
97
+ parameters = [
98
+ gr.inputs.Slider(label="Similarity Threshold", minimum=0, maximum=100, step=1, default=50),
99
+ gr.inputs.Slider(label="Number of Components", minimum=1, maximum=10, step=1, default=2),
100
+ gr.inputs.Slider(label="Max Iterations", minimum=100, maximum=5000, step=100, default=3000),
101
+ gr.inputs.Slider(label="Epsilon", minimum=1e-12, maximum=1e-6, step=1e-12, default=1e-9),
102
+ gr.inputs.Dropdown(label="Colormap", choices=["Blues_r", "Dark2", "Reds_r", "Purples_r"], default="Blues_r")
103
+ ]
104
+
105
+
106
+ iface = gr.Interface(
107
+ fn=plot_similarity_scatter,
108
+ inputs=parameters,
109
+ outputs="image",
110
+ title="Multi-dimensional scaling",
111
+ description="The scatter plot is generated based on the provided data and similarity matrix. MDS and NMDS techniques are used to project the data points into a two-dimensional space. The points are plotted in the scatter plot, with different colors representing the true positions, MDS positions, and NMDS positions of the data points. The similarity threshold parameter allows you to control the visibility of connections between the points. Points with similarity values below the threshold are not connected by lines in the plot. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html",
112
+ examples=[
113
+ [50, 2, 3000, 1e-9, "Blues_r"],
114
+ [75, 3, 2000, 1e-10, "Dark2"],
115
+ [90, 2, 4000, 1e-11, "Reds_r"],
116
+ ],
117
+
118
+ )
119
+
120
+ iface.launch()