# llm-bradley-terry / hdinterval.py
# jerome-white's picture
# Signal when target is not in HDI
# 2673c60
import math
import warnings
import operator as op
import itertools as it
import functools as ft
import statistics as st
from dataclasses import dataclass
@dataclass
class HDI:
    """A closed numeric interval given by its two endpoints.

    Instances unpack as ``(lower, upper)`` and support ``x in interval``
    with both endpoints counted as inside.
    """

    # Left endpoint of the interval.
    lower: float
    # Right endpoint of the interval.
    upper: float

    def __iter__(self):
        """Yield the lower endpoint, then the upper endpoint."""
        yield self.lower
        yield self.upper

    def __contains__(self, item):
        """Return whether *item* lies between the endpoints, inclusive."""
        (lo, hi) = (self.lower, self.upper)
        # Written as two explicit comparisons; any failing comparison
        # (including one against NaN) means "not contained".
        return lo <= item and item <= hi

    def width(self):
        """Return the distance from the lower to the upper endpoint."""
        span = self.upper - self.lower
        return span
class HDInterval:
    """Highest-density-interval estimator over a sample of numbers.

    Non-finite entries are dropped and the remainder is sorted once, on
    first access to :attr:`values`.  Calling the instance returns the
    narrowest interval for a requested mass; :meth:`at` searches for the
    mass at which an interval endpoint lands next to a target value.
    """

    def __init__(self, values):
        """Keep *values* (an iterable of numbers) for lazy processing."""
        self._values = values

    @ft.cached_property
    def values(self):
        """Sorted list of the finite input values, computed once.

        Raises:
            AttributeError: if no finite value is left after filtering.
        """
        view = sorted(filter(math.isfinite, self._values))
        if not view:
            raise AttributeError('Empty data set')

        return view

    #
    # See https://cran.r-project.org/package=HDInterval
    #
    def __call__(self, ci=0.95):
        """Return the narrowest interval for mass *ci* as an ``HDI``.

        With ``n`` sorted values and ``exclude = n - floor(n * ci)``, the
        candidates pair each of the ``exclude`` smallest values with the
        value ``n - exclude`` positions above it; the candidate with the
        smallest width is returned (the lowest one on ties).  ``ci == 1``
        returns the full range of the data.

        Values of *ci* outside ``[0, 1]`` are not validated here.
        """
        if ci == 1:
            # exclude would be 0 below, leaving no candidate to choose
            # from, so the full range is handled explicitly.
            args = (self.values[x] for x in (0, -1))
        else:
            n = len(self.values)
            exclude = n - math.floor(n * ci)
            left = it.islice(self.values, exclude)
            right = it.islice(self.values, n - exclude, None)
            diffs = ((x, y, y - x) for (x, y) in zip(left, right))
            (*args, _) = min(diffs, key=op.itemgetter(-1))

        return HDI(*args)

    def _at(self, target, tolerance, ci=1, jump=1):
        """Bisect on the mass until an interval endpoint is in *tolerance*.

        Starting from *ci*, the mass is lowered by *jump* while *target*
        is inside the interval and raised by *jump* while it is outside;
        *jump* is halved after every move.

        Args:
            target: value the interval boundary should approach.
            tolerance: container answering ``endpoint in tolerance``.
            ci: starting mass.
            jump: starting step size.

        Returns:
            The first mass whose interval has an endpoint in *tolerance*.

        Raises:
            OverflowError: the search moved above a mass of 1, i.e. the
                target is outside even the widest interval tried.
            RecursionError: the step became too small to change the mass
                without a match.  The exception type is retained from the
                earlier recursive implementation so existing handlers
                keep working.
        """
        while True:
            if ci > 1:
                raise OverflowError()

            hdi = self(ci)
            if any(x in tolerance for x in hdi):
                return ci

            adjust = op.sub if target in hdi else op.add
            moved = adjust(ci, jump)
            if moved == ci:
                # The step is below floating-point resolution at this
                # mass; every later (smaller) step would be as well, so
                # the interval can never change again.
                raise RecursionError('Bisection stalled without a match')

            ci = moved
            jump /= 2

    def at(self, target, tolerance=1e-4):
        """Find the mass at which an interval endpoint meets *target*.

        An endpoint "meets" the target when it falls inside
        ``[target, target + tolerance]``.  If the search stalls, the
        tolerance is multiplied by ten, a warning is issued, and the
        search is retried for as long as the tolerance stays below 1.

        Args:
            target: value an interval endpoint should reach.
            tolerance: initial width of the acceptance band; must be
                positive.  A tolerance of 1 or more is never attempted.

        Returns:
            The mass found by the bisection.

        Raises:
            ValueError: *tolerance* is not a positive number.
            OverflowError: propagated from the search when the target is
                outside the interval at a mass of 1.
            FloatingPointError: no tolerance below 1 produced a match.
        """
        # Explicit check rather than ``assert``: assertions vanish under
        # ``python -O``, and a non-positive tolerance would then never
        # reach 1 in the loop below.
        if not tolerance > 0:
            raise ValueError(f'tolerance must be positive: {tolerance}')

        while tolerance < 1:
            band = HDI(target, target + tolerance)
            try:
                return self._at(target, band)
            except RecursionError:
                # Relax the acceptance band tenfold and try again.
                tolerance *= 10
                warnings.warn(f'Tolerance reduced: {tolerance}')

        raise FloatingPointError('Unable to converge')