LenixC committed on
Commit
3c6e069
·
1 Parent(s): b2f9756

Built Gradio implementation of the example.

Browse files
Files changed (2) hide show
  1. app.py +99 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio Implementation: Lenix Carter
2
+ # License: BSD 3-Clause or CC-0
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+
9
+ from sklearn import datasets
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.decomposition import PCA
12
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
13
+ from sklearn.neighbors import KNeighborsClassifier, NeighborhoodComponentsAnalysis
14
+ from sklearn.pipeline import make_pipeline
15
+ from sklearn.preprocessing import StandardScaler
16
+
17
+ matplotlib.use('agg')
18
+
19
def reduce_dimensions(n_neighbors, random_state):
    """Embed the Digits dataset in 2-D with PCA, LDA and NCA and plot each result.

    The dataset is split 50/50 (stratified) into train and test halves. Each
    dimensionality-reduction pipeline is fitted on the training half only, a
    k-nearest-neighbours classifier is trained on the embedded training half,
    and its accuracy on the embedded test half is shown in the plot title.

    Args:
        n_neighbors: Number of neighbours used by the KNN classifier.
        random_state: Seed for the train/test split, PCA and NCA.

    Returns:
        A list of three matplotlib figures, in the order PCA, LDA, NCA, each
        showing the full dataset projected to two dimensions.
    """
    # Defensive: UI widgets may deliver numbers as floats, while scikit-learn
    # expects integers for both of these parameters.
    n_neighbors = int(n_neighbors)
    random_state = int(random_state)

    # Load Digits dataset
    X, y = datasets.load_digits(return_X_y=True)

    # Split into train/test
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, stratify=y, random_state=random_state
    )

    # Reduce dimension to 2 with PCA
    pca = make_pipeline(
        StandardScaler(), PCA(n_components=2, random_state=random_state)
    )

    # Reduce dimension to 2 with LinearDiscriminantAnalysis
    lda = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis(n_components=2))

    # Reduce dimension to 2 with NeighborhoodComponentsAnalysis
    nca = make_pipeline(
        StandardScaler(),
        NeighborhoodComponentsAnalysis(n_components=2, random_state=random_state),
    )

    # Use a nearest neighbor classifier to evaluate the methods
    knn = KNeighborsClassifier(n_neighbors=n_neighbors)

    # Make a list of the methods to be compared
    dim_reduction_methods = [("PCA", pca), ("LDA", lda), ("NCA", nca)]

    dim_red_graphs = []

    for name, model in dim_reduction_methods:
        # Draw on an explicit figure/axes pair instead of pyplot's implicit
        # "current figure", which is shared global state across requests.
        fig, ax = plt.subplots()

        # Fit the method's model
        model.fit(X_train, y_train)

        # Fit a nearest neighbor classifier on the embedded training set
        knn.fit(model.transform(X_train), y_train)

        # Compute the nearest neighbor accuracy on the embedded test set
        acc_knn = knn.score(model.transform(X_test), y_test)

        # Embed the data set in 2 dimensions using the fitted model
        X_embedded = model.transform(X)

        # Plot the projected points and show the evaluation score
        ax.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y, s=30, cmap="Set1")
        ax.set_title(
            "{}, KNN (k={})\nTest accuracy = {:.2f}".format(name, n_neighbors, acc_knn)
        )

        # Deregister the figure from pyplot so it is not kept alive after the
        # caller is done with it; the Figure object itself remains usable.
        plt.close(fig)
        dim_red_graphs.append(fig)
    return dim_red_graphs
73
+
74
# Build the Gradio UI: two sliders feed reduce_dimensions, whose three returned
# figures are shown side by side (PCA, LDA, NCA).
title = "Dimensionality Reduction with Neighborhood Components Analysis"
with gr.Blocks() as demo:
    gr.Markdown(f" # {title}")
    gr.Markdown("""
This example performs and displays the results of Principal Component Analysis, Linear Discriminant Analysis, and Neighborhood Component Analysis on the Digits dataset.

The result shows that NCA produces visually meaningful clustering.

This is based on the example [here](https://scikit-learn.org/stable/auto_examples/neighbors/plot_nca_dim_reduction.html#sphx-glr-auto-examples-neighbors-plot-nca-dim-reduction-py)
""")
    n_neighbors = gr.Slider(
        minimum=2, maximum=10, value=3, step=1, label="Number of Neighbors for KNN"
    )
    random_state = gr.Slider(
        minimum=0, maximum=100, value=0, step=1, label="Random State"
    )
    # The button caption is set through `value`; `label` is not the caption.
    btn = gr.Button(value="Run")
    with gr.Row():
        pca_graph = gr.Plot(label="PCA")
        lda_graph = gr.Plot(label="LDA")
        nca_graph = gr.Plot(label="NCA")
    # Output order must match the order of figures returned by
    # reduce_dimensions.
    btn.click(
        fn=reduce_dimensions,
        inputs=[n_neighbors, random_state],
        outputs=[pca_graph, lda_graph, nca_graph],
    )

if __name__ == '__main__':
    demo.launch()
99
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ matplotlib==3.6.3
2
+ scikit-learn==1.2.2