Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from sklearn.datasets import load_iris
|
4 |
+
from sklearn.svm import SVC
|
5 |
+
from sklearn.model_selection import StratifiedKFold, permutation_test_score
|
6 |
+
import numpy as np
|
7 |
+
import tempfile
|
8 |
+
import os
|
9 |
+
|
10 |
+
def run_permutation_test(display_option, kernel, random_state, n_permutations):
|
11 |
+
iris = load_iris()
|
12 |
+
X = iris.data
|
13 |
+
y = iris.target
|
14 |
+
|
15 |
+
n_uncorrelated_features = 20
|
16 |
+
rng = np.random.RandomState(seed=0)
|
17 |
+
X_rand = rng.normal(size=(X.shape[0], n_uncorrelated_features))
|
18 |
+
|
19 |
+
clf = SVC(kernel=kernel, random_state=random_state)
|
20 |
+
cv = StratifiedKFold(2, shuffle=True, random_state=0)
|
21 |
+
|
22 |
+
score_iris, perm_scores_iris, pvalue_iris = permutation_test_score(
|
23 |
+
clf, X, y, scoring="accuracy", cv=cv, n_permutations=n_permutations
|
24 |
+
)
|
25 |
+
|
26 |
+
score_rand, perm_scores_rand, pvalue_rand = permutation_test_score(
|
27 |
+
clf, X_rand, y, scoring="accuracy", cv=cv, n_permutations=n_permutations
|
28 |
+
)
|
29 |
+
|
30 |
+
original_plot_path = None
|
31 |
+
random_plot_path = None
|
32 |
+
|
33 |
+
if display_option in ['original', 'both']:
|
34 |
+
# Original data
|
35 |
+
fig, ax = plt.subplots()
|
36 |
+
ax.hist(perm_scores_iris, bins=20, density=True)
|
37 |
+
ax.axvline(score_iris, ls="--", color="r")
|
38 |
+
score_label = f"Score on original\ndata: {score_iris:.2f}\n(p-value: {pvalue_iris:.3f})"
|
39 |
+
ax.text(0.7, 10, score_label, fontsize=12)
|
40 |
+
ax.set_xlabel("Accuracy score")
|
41 |
+
ax.set_ylabel("Probability")
|
42 |
+
original_plot_path = os.path.join(tempfile.mkdtemp(), "original_plot.png")
|
43 |
+
plt.savefig(original_plot_path)
|
44 |
+
plt.close()
|
45 |
+
|
46 |
+
if display_option in ['random', 'both']:
|
47 |
+
# Random data
|
48 |
+
fig, ax = plt.subplots()
|
49 |
+
ax.hist(perm_scores_rand, bins=20, density=True)
|
50 |
+
ax.set_xlim(0.13)
|
51 |
+
ax.axvline(score_rand, ls="--", color="r")
|
52 |
+
score_label = f"Score on original\ndata: {score_rand:.2f}\n(p-value: {pvalue_rand:.3f})"
|
53 |
+
ax.text(0.14, 7.5, score_label, fontsize=12)
|
54 |
+
ax.set_xlabel("Accuracy score")
|
55 |
+
ax.set_ylabel("Probability")
|
56 |
+
random_plot_path = os.path.join(tempfile.mkdtemp(), "random_plot.png")
|
57 |
+
plt.savefig(random_plot_path)
|
58 |
+
plt.close()
|
59 |
+
|
60 |
+
return original_plot_path, random_plot_path
|
61 |
+
|
62 |
+
iface = gr.Interface(
|
63 |
+
fn=run_permutation_test,
|
64 |
+
inputs=[
|
65 |
+
gr.inputs.Dropdown(
|
66 |
+
choices=["original", "random", "both"],
|
67 |
+
label="Display Option",
|
68 |
+
default="both"
|
69 |
+
),
|
70 |
+
gr.inputs.Dropdown(
|
71 |
+
choices=["linear", "rbf", "poly"],
|
72 |
+
label="Kernel",
|
73 |
+
default="linear"
|
74 |
+
),
|
75 |
+
gr.inputs.Slider(
|
76 |
+
minimum=0, maximum=10, step=1,
|
77 |
+
label="Random State",
|
78 |
+
default=7
|
79 |
+
),
|
80 |
+
gr.inputs.Slider(
|
81 |
+
minimum=100, maximum=2000, step=100,
|
82 |
+
label="Number of Permutations",
|
83 |
+
default=1000
|
84 |
+
)
|
85 |
+
],
|
86 |
+
outputs=["image", "image"],
|
87 |
+
title="Test with permutations the significance of a classification score",
|
88 |
+
description="This example demonstrates the use of permutation_test_score to evaluate the significance of a cross-validated score using permutations. This operation is being performed on the Iris Dataset. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_permutation_tests_for_classification.html",
|
89 |
+
examples=[
|
90 |
+
["both", "linear", 7, 1000],
|
91 |
+
["original", "rbf", 3, 500],
|
92 |
+
["random", "poly", 5, 1500]
|
93 |
+
],
|
94 |
+
allow_flagging=False
|
95 |
+
)
|
96 |
+
iface.launch()
|