# cmpatino's picture
# Add ROC-AUC score for each feature
# 39ac7ff
# raw
# history blame
# 2.76 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from datasets import load_dataset
import histos
# Fetch the "cmpatino/optimal_observables" dataset at import time.
# NOTE(review): "train" is passed as the second positional argument of
# load_dataset here, separately from the ["train"] lookup below — confirm it
# is the intended value for that parameter.
dataset = load_dataset("cmpatino/optimal_observables", "train")
# Materialize the "train" entry as a pandas DataFrame; get_plot and the
# feature dropdown both read from this module-level frame.
dataset_df = dataset["train"].to_pandas()
# Swap the numeric target codes for the class labels used in plots and
# filters. Series.map with a dict turns any value other than 0 or 1 into NaN.
dataset_df["target"] = dataset_df["target"].map({0: "spin-OFF", 1: "spin-ON"})
def _oriented_roc_auc(pos_samples, neg_samples):
    """Return the ROC AUC obtained by using one feature directly as the score.

    The group with the larger mean is labelled 1 and the other 0 (a tie in
    the means labels ``pos_samples`` as 1), so the result does not depend on
    whether the feature tends to be larger for spin-ON or for spin-OFF.

    Args:
        pos_samples: Feature values of the "spin-ON" rows (pandas Series).
        neg_samples: Feature values of the "spin-OFF" rows (pandas Series).

    Returns:
        The value computed by ``sklearn.metrics.roc_auc_score``.
    """
    y_score = np.concatenate([pos_samples, neg_samples], axis=0)
    pos_is_higher = bool(pos_samples.mean() >= neg_samples.mean())
    y_true = np.concatenate(
        [
            np.full(len(pos_samples), int(pos_is_higher)),
            np.full(len(neg_samples), int(not pos_is_higher)),
        ],
        axis=0,
    )
    return metrics.roc_auc_score(y_true, y_score)


def get_plot(features, n_bins):
    """Build the figure for the currently selected feature(s).

    Args:
        features: Selected column name(s) of ``dataset_df``. One name gives a
            per-class ratio histogram titled with the feature's ROC AUC; two
            names give a 2D histogram coloured by target. A bare string is
            treated as a single selection.
        n_bins: Number of bins forwarded to the plotting call.

    Returns:
        The figure returned by ``histos.ratio_hist`` (one feature) or the
        figure of the seaborn ``displot`` grid (two features); ``None`` when
        the selection is empty, ``None``, or has more than two entries.
    """
    # A plain string would otherwise be measured by its character count.
    if isinstance(features, str):
        features = [features]
    # Nothing to draw for an empty/missing selection or more than two features.
    if not features or len(features) > 2:
        return None

    if len(features) == 1:
        feature = features[0]
        target = dataset_df["target"]
        pos_samples = dataset_df.loc[target == "spin-ON", feature]
        neg_samples = dataset_df.loc[target == "spin-OFF", feature]
        roc_auc_score = _oriented_roc_auc(pos_samples, neg_samples)

        labels = ["spin-ON", "spin-OFF"]
        # ratio_hist returns the figure to display, so no figure is created
        # here; an extra plt.subplots() would stay open in pyplot's registry
        # on every callback.
        return histos.ratio_hist(
            processes_q=[pos_samples, neg_samples],
            hist_labels=labels,
            reference_label=labels[1],
            n_bins=n_bins,
            hist_range=None,
            title=f"{feature} (ROC AUC: {roc_auc_score:.3f})",
        )

    # Exactly two features: bivariate histogram, one hue per target class.
    return sns.displot(
        dataset_df,
        x=features[0],
        y=features[1],
        hue="target",
        bins=n_bins,
        height=8,
        aspect=1,
    ).fig
# Gradio UI: a feature selector and a bin-count slider side by side, with the
# resulting plot underneath.
# NOTE(review): the Row/Column nesting below is reconstructed (the original
# indentation was not available) — confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            features = gr.Dropdown(
                label="Feature",
                choices=dataset_df.columns.to_list(),
                multiselect=True,
                value="m_tt",
            )
            n_bins = gr.Slider(
                minimum=10,
                maximum=100,
                step=10,
                value=10,
                label="Number of Bins for Histogram",
            )
        feature_plot = gr.Plot(label="Feature's Plot")

    # Redraw the plot whenever either control changes and once on page load;
    # all three events share the same handler, inputs and output.
    for register in (features.change, n_bins.change, demo.load):
        register(
            get_plot,
            [features, n_bins],
            feature_plot,
            queue=False,
        )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()