# Hugging Face Space -- status: Running on CPU Upgrade
import json | |
import math | |
import os | |
os.system("pip uninstall -y gradio") | |
os.system("pip install gradio==3.26.0") | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
from sklearn.datasets import fetch_20newsgroups | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.model_selection import RandomizedSearchCV | |
from sklearn.naive_bayes import ComplementNB | |
from sklearn.pipeline import Pipeline | |
# Newsgroup names offered as choices in the "Categories" dropdown, one per
# line; split() turns the block into a list of 20 strings in this order.
CATEGORIES = """
alt.atheism
comp.graphics
comp.os.ms-windows.misc
comp.sys.ibm.pc.hardware
comp.sys.mac.hardware
comp.windows.x
misc.forsale
rec.autos
rec.motorcycles
rec.sport.baseball
rec.sport.hockey
sci.crypt
sci.electronics
sci.med
sci.space
soc.religion.christian
talk.politics.guns
talk.politics.mideast
talk.politics.misc
talk.religion.misc
""".split()
def shorten_param(param_name):
    """Strip the pipeline-step prefix from a parameter name.

    Everything up to and including the last ``"__"`` is dropped, so
    ``"vect__max_df"`` becomes ``"max_df"``. Names without ``"__"`` are
    returned unchanged.
    """
    # rpartition yields ("", "", param_name) when the separator is absent,
    # so the last element is the right answer in both cases.
    _prefix, _separator, short_name = param_name.rpartition("__")
    return short_name
def _parse_numeric_choices(raw, field):
    """Parse a comma-separated string of numbers into a list of ints/floats.

    Blank tokens (e.g. from a trailing comma) are ignored. Raises ValueError,
    prefixed with *field*, when a token is not a plain numeric literal or when
    no value is left.
    """
    import ast  # local import: the file's import block does not provide it

    choices = []
    for token in raw.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            # SECURITY: this text comes straight from a web form. It used to
            # be passed to eval(); literal_eval only accepts literals, so
            # expressions such as "1/2" or function calls are now rejected.
            value = ast.literal_eval(token)
        except (ValueError, SyntaxError, TypeError) as err:
            raise ValueError(f"{field}: {token!r} is not a valid number") from err
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"{field}: {token!r} is not a valid number")
        choices.append(value)
    if not choices:
        raise ValueError(f"{field}: at least one value is required")
    return choices


def _parse_ngram_ranges(raw):
    """Parse text such as ``"(1, 1), (1, 2)"`` into a list of (min_n, max_n) tuples.

    A single range (``"(1, 2)"``) is accepted as well. Raises ValueError when
    the text is not a literal sequence of integer pairs.
    """
    import ast  # local import: the file's import block does not provide it

    field = "vect__ngram_range"
    try:
        # SECURITY: literal_eval instead of eval() -- see _parse_numeric_choices.
        parsed = ast.literal_eval(raw.strip())
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(
            f"{field}: expected tuples such as '(1, 1), (1, 2)', got {raw!r}"
        ) from err
    if not isinstance(parsed, (tuple, list)) or not parsed:
        raise ValueError(
            f"{field}: expected tuples such as '(1, 1), (1, 2)', got {raw!r}"
        )
    # A lone "(1, 2)" parses to a flat pair of ints; without wrapping it the
    # search would treat 1 and 2 as two separate (invalid) candidates.
    if all(isinstance(item, int) for item in parsed):
        parsed = [parsed]
    ranges = []
    for item in parsed:
        is_int_pair = (
            isinstance(item, (tuple, list))
            and len(item) == 2
            and all(
                isinstance(bound, int) and not isinstance(bound, bool)
                for bound in item
            )
        )
        if not is_int_pair:
            raise ValueError(f"{field}: {item!r} is not a pair of integers")
        ranges.append(tuple(item))
    return ranges


def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
    """Run a randomized hyper-parameter search for a TF-IDF + ComplementNB pipeline.

    Parameters
    ----------
    categories : list of str
        Newsgroup names forwarded to ``fetch_20newsgroups``.
    vect__max_df, vect__min_df : str
        Comma-separated numeric literals, e.g. ``"0.2, 0.4"`` or ``"1, 3"``.
    vect__ngram_range : str
        One or more integer pairs, e.g. ``"(1, 1), (1, 2)"``.
    vect__norm : str
        Comma-separated norm names, e.g. ``"l1, l2"``.

    Returns
    -------
    tuple
        ``(fig, fig2, best_parameters, test_accuracy)``: the scatter figure of
        score time against CV score, the parallel-coordinates figure, the best
        estimator's parameters as a JSON string, and the value returned by
        ``random_search.score`` on the test split.

    Raises
    ------
    ValueError
        If any text field cannot be parsed. Only literals are accepted; the
        fields are never evaluated as Python expressions.
    """
    # Parse all user input first so malformed text fails fast, before the
    # dataset is fetched or the search is fitted.
    max_df_choices = _parse_numeric_choices(vect__max_df, "vect__max_df")
    min_df_choices = _parse_numeric_choices(vect__min_df, "vect__min_df")
    ngram_choices = _parse_ngram_ranges(vect__ngram_range)
    norm_choices = [token.strip() for token in vect__norm.split(",") if token.strip()]
    if not norm_choices:
        raise ValueError("vect__norm: at least one value is required")

    parameters_grid = {
        "vect__max_df": max_df_choices,
        "vect__min_df": min_df_choices,
        "vect__ngram_range": ngram_choices,  # unigrams or bigrams
        "vect__norm": norm_choices,
        "clf__alpha": np.logspace(-6, 6, 13),
    }
    print(parameters_grid)

    data_train = fetch_20newsgroups(
        subset="train",
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )
    data_test = fetch_20newsgroups(
        subset="test",
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )

    pipeline = Pipeline(
        [
            ("vect", TfidfVectorizer()),
            ("clf", ComplementNB()),
        ]
    )
    random_search = RandomizedSearchCV(
        estimator=pipeline,
        param_distributions=parameters_grid,
        n_iter=40,
        random_state=0,
        n_jobs=2,
        verbose=1,
    )
    random_search.fit(data_train.data, data_train.target)

    # default=str stringifies the values json cannot encode (the estimator
    # objects inside get_params()).
    best_parameters = json.dumps(
        random_search.best_estimator_.get_params(),
        indent=4,
        sort_keys=True,
        default=str,
    )
    test_accuracy = random_search.score(data_test.data, data_test.target)

    cv_results = pd.DataFrame(random_search.cv_results_)
    cv_results = cv_results.rename(shorten_param, axis=1)
    param_names = [shorten_param(name) for name in parameters_grid]
    labels = {
        "mean_score_time": "CV Score time (s)",
        "mean_test_score": "CV score (accuracy)",
    }

    fig = px.scatter(
        cv_results,
        x="mean_score_time",
        y="mean_test_score",
        error_x="std_score_time",
        error_y="std_test_score",
        hover_data=param_names,
        labels=labels,
    )
    fig.update_layout(
        title={
            "text": "trade-off between scoring time and mean test score",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )

    column_results = param_names + ["mean_test_score", "mean_score_time"]
    # Identity by default; the non-numeric / wide-range columns get a numeric
    # mapping so they can be drawn as parallel-coordinate axes.
    transform_funcs = dict.fromkeys(column_results, lambda x: x)
    # Using a logarithmic scale for alpha
    transform_funcs["alpha"] = math.log10
    # L1 norms are mapped to index 1, and L2 norms to index 2
    transform_funcs["norm"] = lambda x: 2 if x == "l2" else 1
    # Unigrams are mapped to index 1 and bigrams to index 2
    transform_funcs["ngram_range"] = lambda x: x[1]

    fig2 = px.parallel_coordinates(
        cv_results[column_results].apply(transform_funcs),
        color="mean_test_score",
        color_continuous_scale=px.colors.sequential.Viridis_r,
        labels=labels,
    )
    fig2.update_layout(
        title={
            "text": "Parallel coordinates plot of text classifier pipeline",
            "y": 0.99,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig, fig2, best_parameters, test_accuracy
def load_description(name):
    """Return the text of the Markdown file ``./descriptions/<name>.md``.

    *name* may contain a sub-directory (e.g. ``"parameter_grid/alpha"``); the
    path is resolved relative to the current working directory. Errors raised
    by ``open`` (such as FileNotFoundError) propagate to the caller.
    """
    # Explicit UTF-8: without it open() decodes with the platform's locale
    # encoding, which can garble or reject non-ASCII Markdown.
    with open(f"./descriptions/{name}.md", "r", encoding="utf-8") as f:
        return f.read()
# Markdown attribution line (author and the scikit-learn example this demo is
# based on); rendered with gr.Markdown(AUTHOR) in the page layout.
AUTHOR = """
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
"""
# Gradio UI: intro text, category picker, parameter-grid inputs, a train
# button and the result widgets, wired to train_model at the end.
# NOTE(review): the nesting below was reconstructed from source whose
# indentation had been stripped -- confirm the layout against the running app.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # Title, description texts loaded from ./descriptions, and attribution.
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
            gr.Markdown(load_description("description_part1"))
            gr.Markdown(load_description("description_part2"))
            gr.Markdown(AUTHOR)
    # Category picker: multi-select limited to two choices (max_choices=2).
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## CATEGORY SELECTION""")
            gr.Markdown(load_description("description_category_selection"))
            drop_categories = gr.Dropdown(
                CATEGORIES,
                value=["alt.atheism", "talk.religion.misc"],
                multiselect=True,
                label="Categories",
                info="Please select up to two categories that you want to receive training on.",
                max_choices=2,
                interactive=True,
            )
    with gr.Row():
        with gr.Tab("PARAMETERS GRID"):
            gr.Markdown(load_description("description_parameter_grid"))
            with gr.Row():
                with gr.Column():
                    # Display only: interactive=False and not listed in the
                    # click inputs below, so its text never reaches train_model.
                    # NOTE(review): the values shown here differ from the
                    # np.logspace(-6, 6, 13) grid train_model actually uses --
                    # confirm which one is intended.
                    clf__alpha = gr.Textbox(
                        label="Classifier Alpha (clf__alpha)",
                        value="1.e-06, 1.e-05, 1.e-04",
                        info="Due to practical considerations, this parameter was kept constant.",
                        interactive=False,
                    )
                    vect__max_df = gr.Textbox(
                        label="Vectorizer max_df (vect__max_df)",
                        value="0.2, 0.4, 0.6, 0.8, 1.0",
                        info="Values ranging from 0 to 1.0, separated by a comma.",
                        interactive=True,
                    )
                    vect__min_df = gr.Textbox(
                        label="Vectorizer min_df (vect__min_df)",
                        value="1, 3, 5, 10",
                        info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
                        interactive=True,
                    )
                with gr.Column():
                    vect__ngram_range = gr.Textbox(
                        label="Vectorizer ngram_range (vect__ngram_range)",
                        value="(1, 1), (1, 2)",
                        info="""Tuples of integer values separated by a comma. For example an `ngram_range` of `(1, 1)` means only unigrams, `(1, 2)` means unigrams and bigrams, and `(2, 2)` means only bigrams.""",
                        interactive=True,
                    )
                    vect__norm = gr.Textbox(
                        label="Vectorizer norm (vect__norm)",
                        value="l1, l2",
                        info="'l1' or 'l2', separated by a comma",
                        interactive=True,
                    )
        # Static help tab: one heading plus one description file per parameter.
        with gr.Tab("DESCRIPTION OF PARAMETERS"):
            gr.Markdown("""### Classifier Alpha""")
            gr.Markdown(load_description("parameter_grid/alpha"))
            gr.Markdown("""### Vectorizer max_df""")
            gr.Markdown(load_description("parameter_grid/max_df"))
            gr.Markdown("""### Vectorizer min_df""")
            gr.Markdown(load_description("parameter_grid/min_df"))
            gr.Markdown("""### Vectorizer ngram_range""")
            gr.Markdown(load_description("parameter_grid/ngram_range"))
            gr.Markdown("""### Vectorizer norm""")
            gr.Markdown(load_description("parameter_grid/norm"))
    # Shows the pipeline definition as a fenced code snippet (display only).
    # NOTE(review): whitespace inside this literal is reproduced as seen in
    # the stripped source -- the original may have been indented.
    with gr.Row():
        gr.Markdown(
            """
## MODEL PIPELINE
```python
pipeline = Pipeline(
[
("vect", TfidfVectorizer()),
("clf", ComplementNB()),
]
)
```
"""
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## TRAINING""")
            with gr.Row():
                brn_train = gr.Button("Train").style(container=False)
    gr.Markdown("## RESULTS")
    with gr.Row():
        best_parameters = gr.Textbox(label="Best parameters")
        test_accuracy = gr.Textbox(label="Test accuracy")
    plot_trade = gr.Plot(label="")
    plot_coordinates = gr.Plot(label="")
    # The inputs are listed in the order of train_model's parameters, and the
    # outputs in the order of its return tuple
    # (fig, fig2, best_parameters, test_accuracy).
    brn_train.click(
        train_model,
        inputs=[
            drop_categories,
            vect__max_df,
            vect__min_df,
            vect__ngram_range,
            vect__norm,
        ],
        outputs=[plot_trade, plot_coordinates, best_parameters, test_accuracy],
    )
# Runs unconditionally when this file is executed or imported.
app.launch()