File size: 2,636 Bytes
fe70438 82055e6 d08fbc6 82055e6 d08fbc6 24df49f d08fbc6 82055e6 d08fbc6 fe70438 d08fbc6 24df49f d08fbc6 fe70438 d08fbc6 88c61d3 fe70438 d08fbc6 fe70438 d08fbc6 82055e6 d08fbc6 82055e6 d08fbc6 82055e6 d08fbc6 82055e6 d08fbc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from abc import abstractmethod
from typing import Dict, List, Optional, Union
from .dataclass import NonPositionalField
from .formats import Format
from .fusion import FixedFusion
from .operator import SourceOperator
from .standard import DatasetRecipe
from .stream import MultiStream
from .system_prompts import SystemPrompt
class BaseBenchmark(SourceOperator):
    """Base source operator for benchmarks.

    Declares the optional, non-positional configuration shared by benchmarks
    (format, number of demos, system prompt, loader limit, splits, subset
    selection) and requires subclasses to implement ``reset``.
    """

    # Optional format override; None means "not set".
    format: Optional[Format] = NonPositionalField(default=None)
    # Optional override for the number of demonstrations; None means "not set".
    num_demos: Optional[int] = NonPositionalField(default=None)
    # Optional system-prompt override; None means "not set".
    system_prompt: Optional[SystemPrompt] = NonPositionalField(default=None)
    # Optional loader-limit override; None means "not set".
    loader_limit: Optional[int] = NonPositionalField(default=None)
    # Splits to include. default_factory builds a fresh list per instance so
    # instances never share (and mutate) one default list.
    splits: List[str] = NonPositionalField(
        default_factory=lambda: ["train", "validation", "test"]
    )
    # Optional name of a single subset to restrict processing to; None = all.
    subset: Optional[str] = NonPositionalField(default=None)

    @abstractmethod
    def reset(self):
        """Re-apply this benchmark's configuration; implemented by subclasses."""
        pass
class Benchmark(BaseBenchmark):
    """A benchmark made of named subsets fused into a single MultiStream.

    Each subset is either a ``DatasetRecipe`` or a nested ``BaseBenchmark``.
    Benchmark-level overrides (``num_demos``, ``format``, ``system_prompt``,
    ``loader_limit``) are pushed down to every subset by :meth:`reset`, which
    runs as part of :meth:`prepare`.
    """

    # Mapping from subset name to the recipe or nested benchmark producing it.
    subsets: Dict[str, Union[DatasetRecipe, BaseBenchmark]]
    # Cap on the total number of samples across subsets. Not implemented yet:
    # setting it makes ``process`` raise NotImplementedError.
    max_total_samples: Optional[int] = None
    # Per-subset sample cap, forwarded to FixedFusion as
    # ``max_instances_per_subset``.
    max_samples_per_subset: Optional[int] = None

    def verify(self):
        """Reject configurations that set both sample caps at once.

        Raises:
            ValueError: if ``max_total_samples`` and ``max_samples_per_subset``
                are both set.
        """
        super().verify()
        if (
            self.max_total_samples is not None
            and self.max_samples_per_subset is not None
        ):
            raise ValueError("Set either max_total_samples or max_samples_per_subset")

    def prepare_args(self):
        """Store ``subsets`` as a plain ``dict`` (shallow copy of the given mapping)."""
        # NOTE(review): the copy presumably decouples this benchmark from the
        # caller's mapping and normalizes other Mapping types — confirm intent.
        self.subsets = dict(self.subsets)

    def reset(self):
        """Push benchmark-level overrides down to every subset and reset it.

        Every one of ``num_demos``, ``format``, ``system_prompt`` and
        ``loader_limit`` that is not None is assigned onto each subset (in that
        order), after which the subset's own ``reset()`` is called. When none of
        them is set, subsets are left untouched and are NOT reset.
        """
        overrides = {
            name: getattr(self, name)
            for name in ("num_demos", "format", "system_prompt", "loader_limit")
            if getattr(self, name) is not None
        }
        if not overrides:
            return
        for subset in self.subsets.values():
            for name, value in overrides.items():
                setattr(subset, name, value)
            subset.reset()

    def prepare(self):
        """Run the parent preparation, then propagate overrides via ``reset``."""
        super().prepare()
        self.reset()

    def process(
        self,
    ) -> MultiStream:
        """Fuse the selected subsets into a single MultiStream.

        If ``subset`` is set, only that subset is used; otherwise all subsets
        are fused. ``splits`` is forwarded to FixedFusion as ``include_splits``.

        Raises:
            KeyError: if ``subset`` names a subset that is not in ``subsets``.
            NotImplementedError: if ``max_total_samples`` is set.
        """
        if self.subset is None:
            subsets = self.subsets
        elif self.subset in self.subsets:
            subsets = {self.subset: self.subsets[self.subset]}
        else:
            # Same exception type as the bare dict lookup, but say what is valid.
            raise KeyError(
                f"Unknown subset {self.subset!r}; "
                f"available subsets: {list(self.subsets)}"
            )

        if self.max_total_samples is not None:
            raise NotImplementedError(
                "max_total_samples is not supported yet; "
                "use max_samples_per_subset instead."
            )

        operator = FixedFusion(
            subsets=subsets,
            max_instances_per_subset=self.max_samples_per_subset,
            include_splits=self.splits,
        )
        return operator()
|