Raaniel commited on
Commit
64b69c1
·
1 Parent(s): 474f65f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Authors: Jona Sassenhagen
2
+ # License: BSD 3 clause
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from sklearn.decomposition import FactorAnalysis, PCA
7
+ from sklearn.preprocessing import StandardScaler
8
+ from sklearn.datasets import load_iris
9
+ import gradio as gr
10
+ from matplotlib import style
11
+ plt.switch_backend("agg")
12
+ style.use('ggplot')
13
+
14
+ font1 = {'family':'DejaVu Sans','color':'#2563EB','size': 14}
15
+
16
+ #load and transform the data
17
+ data = load_iris()
18
+ X = StandardScaler().fit_transform(data["data"])
19
+ feature_names = data["feature_names"]
20
+
21
+ methods = {
22
+ "PCA": PCA(),
23
+ "Unrotated FA": FactorAnalysis(),
24
+ "Varimax FA": FactorAnalysis(rotation="varimax")
25
+ }
26
+
27
+ def factor_analysis(method):
28
+ #figure1
29
+ fig1, ax = plt.subplots(figsize=(10, 6), facecolor='none', dpi = 200)
30
+ im = ax.imshow(np.corrcoef(X.T), cmap="Spectral", vmin=-1, vmax=1)
31
+
32
+ ax.set_xticks([0, 1, 2, 3])
33
+ ax.set_xticklabels(list(feature_names),
34
+ rotation=90, fontdict = font1)
35
+ ax.set_yticks([0, 1, 2, 3])
36
+ ax.set_yticklabels(list(feature_names), fontdict = font1)
37
+ plt.grid(False)
38
+ plt.colorbar(im).ax.tick_params()
39
+ ax.set_title("Iris feature correlation matrix",
40
+ fontdict=font1, size = 18,
41
+ color = "white", pad = 20,
42
+ bbox=dict(boxstyle="round,pad=0.3",
43
+ color = "#2563EB"))
44
+ plt.tight_layout()
45
+ plt.close('all')
46
+
47
+ n_comps = 2
48
+
49
+ #figure2
50
+ fig2, axs = plt.subplots(figsize=(8, 5), facecolor='none', dpi = 200)
51
+ plt.grid(False)
52
+ fa = methods[method]
53
+ fa.set_params(n_components=n_comps)
54
+ fa.fit(X)
55
+
56
+ components = fa.components_
57
+
58
+ vmax = np.abs(components).max()
59
+ axs.imshow(components, cmap="Spectral", vmax=vmax, vmin=-vmax)
60
+ axs.set_xticks(np.arange(len(feature_names)))
61
+ axs.set_xticklabels(feature_names, fontdict=font1)
62
+ axs.set_title(method,
63
+ fontdict=font1, size = 18,
64
+ color = "white", pad = 20,
65
+ bbox=dict(boxstyle="round,pad=0.3",
66
+ color = "#2563EB"))
67
+ axs.set_yticks([0, 1])
68
+ axs.set_yticklabels(["Comp. 1", "Comp. 2"], fontdict=font1)
69
+
70
+ plt.tight_layout()
71
+ plt.close('all')
72
+
73
+ return fig1, fig2, components
74
+
75
+ intro = """<h1 style="text-align: center;">🤗 <strong>Factor Analysis (with rotation) to visualize patterns</strong> 🤗</h1>
76
+ """
77
+ desc = """<h3 style="text-align: left;"> Investigating the Iris dataset, we see that sepal length, petal length and petal width are highly correlated.
78
+ Sepal width is less redundant. Matrix decomposition techniques can uncover these latent patterns.
79
+ <br><br>Applying rotations to the resulting components does not inherently improve the predictive value of the derived latent space,
80
+ but can help visualise their structure; here, for example, the varimax rotation,
81
+ which is found by maximizing the squared variances of the weights,
82
+ finds a structure where the second component only loads positively on sepal width.
83
+ <br></h3>
84
+ """
85
+
86
+ made ="""<div style="text-align: center;">
87
+ <p>Made with ❤</p>"""
88
+
89
+ link = """<div style="text-align: center;">
90
+ <a href="https://scikit-learn.org/stable/auto_examples/decomposition/plot_varimax_fa.html#sphx-glr-auto-examples-decomposition-plot-varimax-fa-py" target="_blank" rel="noopener noreferrer">
91
+ Demo is based on this script from scikit-learn documentation</a>"""
92
+
93
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue",
94
+ secondary_hue="sky",
95
+ neutral_hue="neutral",
96
+ font = gr.themes.GoogleFont("Roboto")),
97
+ title="Factor-Analysis-with-rotation") as demo:
98
+ gr.HTML(intro)
99
+ gr.HTML(desc)
100
+ method = gr.Radio(["PCA", "Unrotated FA", "Varimax FA"], label = "Choose method to show on the plot:", value = "PCA")
101
+ with gr.Box():
102
+ with gr.Column():
103
+ components = gr.Dataframe(headers= feature_names,label = "Loadings")
104
+ with gr.Row():
105
+ fig1 = gr.Plot(label="Plot covariance of Iris features")
106
+ fig2 = gr.Plot(label="Factor analysis")
107
+
108
+ method.change(fn=factor_analysis, inputs=method, outputs=[fig1, fig2, components])
109
+ demo.load(fn=factor_analysis, inputs=method, outputs=[fig1, fig2, components])
110
+ gr.HTML(made)
111
+ gr.HTML(link)
112
+
113
+ demo.launch()