|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
from sklearn import svm, datasets |
|
from sklearn.inspection import DecisionBoundaryDisplay |
|
|
|
def plot_svm_classifiers():
    """Fit four SVM classifiers on two Iris features and plot their decision boundaries.

    Only the first two columns of ``iris.data`` (sepal length and sepal width)
    are used, so each classifier's decision surface can be drawn in 2D.

    Returns:
        matplotlib.figure.Figure: a 2x2 grid with one subplot per classifier,
        showing the regions from ``predict`` and the training points coloured
        by their true class.
    """
    # Use the object-oriented Figure API rather than pyplot: this function is
    # a callback in a long-running web server, where pyplot's global figure
    # registry would keep every generated figure alive and is not thread-safe.
    from matplotlib.figure import Figure

    iris = datasets.load_iris()
    # Keep only the first two features so the boundaries are plottable in 2D.
    X = iris.data[:, :2]
    y = iris.target

    # Regularization parameter shared by every model.
    C = 1.0
    models = (
        svm.SVC(kernel="linear", C=C),
        svm.LinearSVC(C=C, max_iter=10000),
        svm.SVC(kernel="rbf", gamma=0.7, C=C),
        svm.SVC(kernel="poly", degree=3, gamma="auto", C=C),
    )
    titles = (
        "SVC with linear kernel",
        "LinearSVC (linear kernel)",
        "SVC with RBF kernel",
        "SVC with polynomial (degree 3) kernel",
    )

    fig = Figure()
    axes = fig.subplots(2, 2)
    fig.subplots_adjust(wspace=0.4, hspace=0.4)

    X0, X1 = X[:, 0], X[:, 1]

    for clf, title, ax in zip(models, titles, axes.flat):
        clf.fit(X, y)
        DecisionBoundaryDisplay.from_estimator(
            clf,
            X,
            response_method="predict",
            cmap=plt.cm.coolwarm,
            alpha=0.8,
            ax=ax,
            xlabel=iris.feature_names[0],
            ylabel=iris.feature_names[1],
        )
        ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors="k")
        ax.set_xticks(())
        ax.set_yticks(())
        ax.set_title(title)
        # Tighten this subplot's limits explicitly; pyplot's plt.axis() would
        # only affect pyplot's "current" axes, not the subplot being drawn.
        ax.axis("tight")

    return fig
|
|
|
# NOTE(review): the characters before "Plot" look like mis-encoded emoji
# (mojibake); kept verbatim because the intended glyphs cannot be recovered
# from here -- confirm and replace with the intended prefix.
heading = 'π€π§‘π€π Plot different SVM Classifiers on Iris Dataset'

# Introductory Markdown shown under the page heading. Do not end lines with a
# backslash here: inside a non-raw Python string "\<newline>" is a line
# continuation that Python strips, so it never reaches Markdown as a hard break.
description = """
### This demo visualizes different SVM Classifiers on a 2D projection of the Iris dataset.

<b>The features to be considered are:</b>

1. Sepal length (cm)
2. Sepal width (cm)

<b>The SVM Classifiers used for this demo are:</b>

1. SVC with linear kernel
2. Linear SVC
3. SVC with RBF kernel
4. SVC with Polynomial (degree 3) kernel
"""

with gr.Blocks(title=heading, theme='snehilsanyal/scikit-learn') as demo:
    gr.Markdown(f"# {heading}")
    gr.Markdown(description)
    gr.Markdown('**[Demo is based on this script from scikit-learn documentation](https://scikit-learn.org/stable/auto_examples/svm/plot_iris_svc.html#sphx-glr-auto-examples-svm-plot-iris-svc-py)**')
    button = gr.Button(value='Visualize different SVM Classifiers on Iris Dataset')
    # Created after the button so it is laid out below it.
    plot = gr.Plot()
    button.click(plot_svm_classifiers, outputs=plot)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch a blocking web server.
if __name__ == "__main__":
    demo.launch()