Spaces:
Sleeping
Sleeping
import math | |
import operator as op | |
import itertools as it | |
import functools as ft | |
import collections as cl | |
from pathlib import Path | |
from dataclasses import fields, asdict | |
import pandas as pd | |
import gradio as gr | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from datasets import load_dataset | |
from scipy.special import expit | |
from matplotlib.ticker import FixedLocator, StrMethodFormatter | |
from hdinterval import HDI, HDInterval | |
# One leaderboard tab: its display name, the documentation
# sub-directory to read from, and the name of the dataset to load.
TabGroup = cl.namedtuple('TabGroup', ['name', 'docs', 'dataset'])
# | |
# | |
# | |
def load(repo):
    """Load the samples of the ``alpha`` parameter from a dataset.

    The ``train`` split of *repo* is converted to a data frame, its
    ``element`` column is renamed to ``model``, and only the columns
    ``chain``, ``sample``, ``parameter``, ``model`` and ``value`` are
    retained. Rows are then restricted to those whose ``parameter``
    equals ``"alpha"``, after which the ``parameter`` column is
    dropped.

    :param repo: dataset identifier; either a string or a path object
        whose parts are the identifier's components
    :return: data frame with the remaining columns
    """
    parameter = 'parameter'
    model = 'model'
    items = [
        'chain',
        'sample',
        parameter,
        model,
        'value',
    ]

    # str() on a path object uses the platform separator, which turns
    # the identifier into "namespace\\name" on Windows. Render path
    # objects with forward slashes instead.
    name = repo.as_posix() if isinstance(repo, Path) else str(repo)
    dataset = load_dataset(name)

    return (dataset
            .get('train')
            .to_pandas()
            .rename(columns={'element': model})
            .filter(items=items)
            .query(f'{parameter} == "alpha"')
            .drop(columns=parameter))
def summarize(df, ci):
    """Reduce the samples to a single summary row per model.

    Each row holds the model name, the median of its ``value`` column
    (``ability``), the width of its highest density interval at mass
    *ci* (``uncertainty``), and every field of that interval.

    :param df: data frame with ``model`` and ``value`` columns
    :param ci: interval mass handed to the ``HDInterval`` instance
    :return: data frame with one row per model, in order of first
        appearance
    """
    def _rows():
        for (name, group) in df.groupby('model', sort=False):
            samples = group['value']
            interval = HDInterval(samples)(ci)
            yield {
                'model': name,
                'ability': samples.median(),
                'uncertainty': interval.width(),
                **asdict(interval),
            }

    return pd.DataFrame.from_records(_rows())
def rank(df, ascending, name='rank'):
    """Order the summary rows and number them starting from one.

    Rows are sorted by ``ability`` in the requested direction; ties
    are broken by ``uncertainty`` in the opposite direction. The
    ``uncertainty`` column is dropped from the result.

    :param df: data frame with ``ability`` and ``uncertainty`` columns
    :param ascending: sort direction applied to ``ability``
    :param name: name of the leading column holding the position
    :return: data frame whose first column is the 1-based position
    """
    ordered = df.sort_values(
        by=['ability', 'uncertainty'],
        ascending=[ascending, not ascending],
    )
    ordered = ordered.drop(columns='uncertainty')
    ordered = ordered.reset_index(drop=True)

    # Shift the fresh 0-based index so that positions start at one,
    # then promote it to an ordinary column.
    ordered.index = ordered.index + 1

    return ordered.reset_index(names=name)
def compare(df, model_1, model_2):
    """Return the logistic of the per-draw difference of two models.

    The rows of the two models are reshaped so that each
    (``chain``, ``sample``) pair has one ``value`` column per model;
    the second model's column is subtracted from the first and the
    result is passed through ``expit``.

    :param df: data frame with ``chain``, ``sample``, ``model`` and
        ``value`` columns
    :param model_1: name of the minuend model
    :param model_2: name of the subtrahend model
    :return: series indexed by (``chain``, ``sample``)
    """
    selected = df[df['model'].isin([model_1, model_2])]
    wide = selected.pivot(index=['chain', 'sample'],
                          columns='model',
                          values='value')
    difference = wide[model_1] - wide[model_2]

    return expit(difference)
# | |
# | |
# | |
class DataPlotter:
    """Base class for figures drawn from a data object.

    Subclasses implement :meth:`draw`; :meth:`plot` supplies the
    figure, the grid styling and the layout.
    """

    def __init__(self, df):
        self.df = df

    def plot(self):
        """Draw onto a new figure and return it.

        :return: the populated matplotlib figure
        """
        fig = plt.figure(dpi=200)
        try:
            ax = fig.gca()
            self.draw(ax)
            ax.grid(visible=True,
                    axis='both',
                    alpha=0.25,
                    linestyle='dotted')
            fig.tight_layout()
        finally:
            # pyplot keeps a reference to every figure it creates until
            # the figure is closed. Closing only unregisters it; the
            # returned object can still be rendered by the caller.
            plt.close(fig)

        return fig

    def draw(self, ax):
        """Render the data onto *ax*; must be provided by subclasses."""
        raise NotImplementedError()
class RankPlotter(DataPlotter):
    """Plot the highest-ability models with their HDIs."""

    # Name of the column that holds each model's vertical position.
    _y = 'y'

    @property
    def y(self):
        """Vertical positions of the plotted models.

        Exposed as a property because :meth:`draw` hands ``self.y``
        straight to matplotlib as data.
        """
        return self.df[self._y]

    def __init__(self, df, ci=0.95, top=10):
        """Summarize and rank *df*, keeping the *top* best models.

        :param df: samples as accepted by ``summarize``
        :param ci: HDI mass used for the summary and the axis label
        :param top: number of models to keep
        """
        self.ci = ci

        # Ranking in ascending order puts the highest abilities last,
        # so the tail holds the best models; they are then reordered so
        # that the largest position comes first.
        view = rank(summarize(df, self.ci), True, self._y)
        view = (view
                .tail(top)
                .sort_values(by=self._y, ascending=False))
        super().__init__(view)

    def draw(self, ax):
        """Draw one point per model with a horizontal HDI bar."""
        self.df.plot.scatter('ability', self._y, ax=ax)
        ax.hlines(self.y,
                  xmin=self.df['lower'],
                  xmax=self.df['upper'],
                  alpha=0.5)

        # Reuse the label set by the scatter plot, title-cased, and
        # append the interval mass.
        ax.set_xlabel('{} (with {:.0%} HDI)'.format(
            ax.get_xlabel().title(),
            self.ci,
        ))
        ax.set_ylabel('')
        ax.set_yticks(self.y, self.df['model'])
class ComparisonPlotter(DataPlotter):
    """Plot the distribution produced by ``compare`` for two models."""

    # Value whose membership in the HDI is reported on the plot.
    _uncertain = 0.5

    def __init__(self, df, model_1, model_2, ci):
        """Compare two models and prepare the interval calculator.

        :param df: samples as accepted by ``compare``
        :param model_1: first model name
        :param model_2: second model name
        :param ci: HDI mass to shade
        """
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        """Draw the histogram, the median/HDI markers and the note."""
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)

        # Draw on the axes that were handed in rather than on whatever
        # pyplot currently considers the active axes.
        sns.histplot(data=self.df,
                     stat='density',
                     color=c_hist,
                     ax=ax)
        ax.set_xlabel('\u03B1$_{1}$ - \u03B1$_{2}$')

        self.pr(ax, hdi, c_hdi)
        self.min_inclusive(ax)

    def min_inclusive(self, ax):
        """Annotate the upper-left corner with an HDI membership note.

        The note uses the mass returned by ``self.interval.at`` for
        ``_uncertain``. If that call raises ``OverflowError`` the note
        states that the value is outside the 100% HDI; if it raises
        ``FloatingPointError`` no note is drawn.
        """
        try:
            ci = self.interval.at(self._uncertain)
            inclusive = '\u2208'
        except OverflowError:
            ci = 1
            inclusive = '\u2209'
        except FloatingPointError:
            return

        ax.text(x=0.02,
                y=0.975,
                s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
                fontsize='small',
                fontstyle='italic',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        """Mark the median and shade the HDI.

        The median is drawn as a dashed vertical line and labelled on a
        secondary axis along the top; the span between ``hdi.lower``
        and ``hdi.upper`` is shaded behind the histogram.

        :param ax: axes to decorate
        :param hdi: interval providing ``lower`` and ``upper``
        :param color: color shared by the line and the shading
        """
        x = self.df.median()
        zorder = ax.zorder - 1

        # Show the median with one more decimal place than the first
        # tick label. Fall back to two places when there are no tick
        # labels or the label has no fractional part.
        labels = ax.get_xticklabels()
        text = labels[0].get_text() if labels else ''
        (_, point, fraction) = text.partition('.')
        decimals = len(fraction) + 1 if point else 2

        # The quadruple braces survive both the f-string and the later
        # str.format call made by the formatter, leaving "{1}"/"{2}"
        # for the math-text subscripts.
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'

        ax.axvline(x=x,
                   color=color,
                   linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)

        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))
# | |
# | |
# | |
class ComparisonMenu:
    """Input widgets and click handler for the model comparison."""

    def __init__(self, df, ci=95):
        """
        :param df: samples passed on to ``ComparisonPlotter``
        :param ci: initial HDI value, as a percentage
        """
        self.df = df
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        """Return the comparison figure, or ``None`` if a model is unset.

        :param model_1: first selected model
        :param model_2: second selected model
        :param ci: HDI as a percentage; converted to a fraction here
        """
        if not (model_1 and model_2):
            return None

        plotter = ComparisonPlotter(self.df, model_1, model_2, ci / 100)
        return plotter.plot()

    def build_and_get(self):
        """Create and yield the two model dropdowns and the HDI field.

        The dropdown choices are the distinct model names sorted
        without regard to case.
        """
        choices = sorted(self.df['model'].unique(),
                         key=op.methodcaller('lower'))

        for position in (1, 2):
            yield gr.Dropdown(label=f'Model {position}', choices=choices)

        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)
# | |
# | |
# | |
class DocumentationReader:
    """Look up Markdown documents beneath a root directory by name."""

    # Suffix given to every requested document name.
    _suffix = '.md'

    def __init__(self, root):
        """
        :param root: path object of the directory holding the documents
        """
        self.root = root

    def __getitem__(self, item):
        """Return the text of the document called *item*.

        The file read is *item* under the root directory with its
        suffix set to ``.md``. The file is decoded as UTF-8 so that
        the result does not depend on the locale of the host.

        :param item: document name
        :return: file contents as a string
        """
        path = self.root.joinpath(item).with_suffix(self._suffix)
        return path.read_text(encoding='utf-8')
# | |
# | |
# | |
def layout(tab, ci=0.95):
    """Build the interface for one tab inside the active Gradio context.

    The tab consists of the readme next to the ranking plot, the full
    ranking table, the interactive model comparison, and a collapsed
    disclaimer.

    :param tab: object with ``dataset`` and ``docs`` attributes naming
        the dataset to load and the documentation directory
    :param ci: HDI mass, as a fraction, for the ranking plot and table
    """
    # NOTE(review): the nesting of the ``with`` blocks below was
    # reconstructed from a copy of this file that had lost its
    # indentation -- confirm against the deployed layout.
    df = load(Path('jerome-white', tab.dataset))
    # Read from the tab that was passed in rather than from a global
    # that merely happens to share its value.
    docs = DocumentationReader(Path('docs', tab.docs))

    with gr.Row():
        with gr.Column():
            gr.Markdown(docs['readme'])

        with gr.Column():
            plotter = RankPlotter(df, ci)
            gr.Plot(plotter.plot())

    with gr.Row():
        view = rank(summarize(df, ci), False)

        # Interval fields get a descriptive header; every other column
        # is simply title-cased.
        columns = {x.name: f'{ci:.0%} HDI {x.name}' for x in fields(HDI)}
        for i in view.columns:
            columns.setdefault(i, i.title())
        view = (view
                .rename(columns=columns)
                .style.format(precision=4))

        gr.Dataframe(view)

    with gr.Row():
        with gr.Column(scale=3):
            display = gr.Plot()

        with gr.Row():
            with gr.Column():
                gr.Markdown(f'''
                Probability that Model 1 is preferred to Model 2. The
                histogram is represents the distribution of the
                difference in estimated model abilities. The dashed
                vertical line is its median. The shaded region
                demarcates the chosen [highest density
                interval](https://cran.r-project.org/package=HDInterval)
                (HDI). The note in the upper left denotes the smallest
                HDI that is inclusive of
                {ComparisonPlotter._uncertain}.
                ''')

            with gr.Column():
                menu = ComparisonMenu(df)
                inputs = list(menu.build_and_get())

                button = gr.Button(value='Compare!')
                button.click(menu, inputs=inputs, outputs=[display])

    with gr.Accordion('Disclaimer', open=False):
        gr.Markdown(docs['disclaimer'])
# | |
# | |
# | |
# Assemble the application: one tab per leaderboard, each populated by
# layout(), then start serving.
with gr.Blocks() as demo:
    tabs = [TabGroup(*spec) for spec in (
        ('Chatbot Arena', 'arena', 'arena-bt-stan'),
        ('Alpaca', 'alpaca', 'alpaca-bt-stan'),
    )]

    for t in tabs:
        with gr.Tab(t.name):
            layout(t)

demo.launch()