Raaniel's picture
Upload 2 files
9e585d8
raw
history blame
4.32 kB
# Authors: Jona Sassenhagen
# License: BSD 3 clause
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import FactorAnalysis, PCA
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
import gradio as gr
from matplotlib import style
# Use the non-interactive Agg backend: figures are rendered off-screen and
# returned to the web UI rather than shown in a window.
plt.switch_backend("agg")
style.use("ggplot")

# Shared font settings applied to tick labels and titles of both figures.
font1 = dict(family="Comic Sans MS", color="#2563EB", size=14)

# Load the Iris dataset and standardise its features with StandardScaler.
data = load_iris()
feature_names = data["feature_names"]
X = StandardScaler().fit_transform(data["data"])

# Decomposition estimators selectable in the UI, keyed by their display name.
methods = {
    "PCA": PCA(),
    "Unrotated FA": FactorAnalysis(),
    "Varimax FA": FactorAnalysis(rotation="varimax"),
}
def _style_title(ax, text):
    """Give *ax* the shared title look: white text in a rounded blue box."""
    ax.set_title(
        text,
        fontdict=font1, size=18,
        color="white", pad=20,
        bbox=dict(boxstyle="round,pad=0.3", color="#2563EB"),
    )


def _plot_correlation_matrix():
    """Return a heatmap figure of the correlation matrix of the columns of ``X``."""
    fig, ax = plt.subplots(figsize=(10, 6), facecolor='none', dpi=200)
    im = ax.imshow(np.corrcoef(X.T), cmap="Spectral", vmin=-1, vmax=1)
    # Derive tick positions from the feature list instead of hard-coding 0..3.
    ticks = np.arange(len(feature_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(list(feature_names), rotation=90, fontdict=font1)
    ax.set_yticks(ticks)
    ax.set_yticklabels(list(feature_names), fontdict=font1)
    # Per-axes / per-figure calls: pyplot's "current figure" is process-wide
    # state and must not be relied on when callbacks can overlap.
    ax.grid(False)
    fig.colorbar(im, ax=ax)
    _style_title(ax, "Iris feature correlation matrix")
    fig.tight_layout()
    # Deregister only this figure from pyplot; the Figure object stays usable.
    plt.close(fig)
    return fig


def _plot_loadings(method, components):
    """Return a heatmap figure of *components*, titled with *method*."""
    fig, ax = plt.subplots(figsize=(8, 5), facecolor='none', dpi=200)
    ax.grid(False)
    # Symmetric colour limits so that zero loading maps to the colormap centre.
    vmax = np.abs(components).max()
    ax.imshow(components, cmap="Spectral", vmax=vmax, vmin=-vmax)
    ax.set_xticks(np.arange(len(feature_names)))
    ax.set_xticklabels(feature_names, fontdict=font1)
    _style_title(ax, method)
    # One row per fitted component; labels follow the actual component count.
    n_rows = components.shape[0]
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels([f"Comp. {i + 1}" for i in range(n_rows)],
                       fontdict=font1)
    fig.tight_layout()
    plt.close(fig)
    return fig


def factor_analysis(method):
    """Fit the chosen decomposition on ``X`` and build the two result figures.

    Parameters
    ----------
    method : str
        Key into the module-level ``methods`` dict selecting the estimator.

    Returns
    -------
    tuple
        ``(fig1, fig2, components)``: the feature correlation heatmap, the
        heatmap of the fitted components, and the estimator's ``components_``
        array itself.

    Raises
    ------
    KeyError
        If *method* is not a key of ``methods``.
    """
    n_comps = 2

    # figure1: correlation structure of the standardised features
    fig1 = _plot_correlation_matrix()

    # The estimator object is shared at module level; it is reconfigured and
    # refitted on every call.
    fa = methods[method]
    fa.set_params(n_components=n_comps)
    fa.fit(X)
    components = fa.components_

    # figure2: loadings of each component on the original features
    fig2 = _plot_loadings(method, components)

    return fig1, fig2, components
# Page heading (HTML) shown at the top of the demo.
intro = """<h1 style="text-align: center;">🤗 <strong>Factor Analysis (with rotation) to visualize patterns</strong> 🤗</h1>
"""
# Explanatory paragraph (HTML) shown under the heading.
desc = """<h3 style="text-align: left;"> Investigating the Iris dataset, we see that sepal length, petal length and petal width are highly correlated.
Sepal width is less redundant. Matrix decomposition techniques can uncover these latent patterns.
<br><br>Applying rotations to the resulting components does not inherently improve the predictive value of the derived latent space,
but can help visualise their structure; here, for example, the varimax rotation,
which is found by maximizing the squared variances of the weights,
finds a structure where the second component only loads positively on sepal width.
<br></h3>
"""
# Footer snippets (HTML); each <div> is closed so the fragment is well-formed.
made = """<div style="text-align: center;">
<p>Made with ❤</p></div>"""
link = """<div style="text-align: center;">
<a href="https://scikit-learn.org/stable/auto_examples/decomposition/plot_varimax_fa.html#sphx-glr-auto-examples-decomposition-plot-varimax-fa-py" target="_blank" rel="noopener noreferrer">
Demo is based on this script from scikit-learn documentation</a></div>"""
# Assemble the Gradio page: heading, description, method selector, a button,
# the loadings table with the two plots, and the footer snippets.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="sky",
        neutral_hue="neutral",
        font=gr.themes.GoogleFont("Roboto"),
    ),
    title="Factor-Analysis-with-rotation",
) as demo:
    gr.HTML(intro)
    gr.HTML(desc)
    # Take the choices from `methods` so the selector can only offer names
    # that the callback is able to look up.
    method = gr.Radio(
        list(methods),
        label="Choose method to show on the plot:",
        value="PCA",
    )
    btn = gr.Button()
    # NOTE(review): gr.Box is not available in newer Gradio major versions —
    # confirm the pinned Gradio version before upgrading.
    with gr.Box():
        with gr.Column():
            components = gr.Dataframe(headers=feature_names, label="Loadings")
            with gr.Row():
                fig1 = gr.Plot(label="Plot covariance of Iris features")
                fig2 = gr.Plot(label="Factor analysis")
    # Output order must match the tuple returned by factor_analysis.
    btn.click(
        fn=factor_analysis,
        inputs=method,
        outputs=[fig1, fig2, components],
    )
    gr.HTML(made)
    gr.HTML(link)

demo.launch()