|
from abc import ABC, abstractmethod |
|
from dataclasses import dataclass |
|
from functools import lru_cache |
|
from typing import Any, Generic, TypeVar |
|
|
|
import datasets |
|
|
|
|
|
# Entry type produced by a parser; constrained (via forward reference) to
# ParseEntry or one of its subclasses.
T = TypeVar("T", bound="ParseEntry")
|
|
|
|
|
@dataclass(frozen=True)
class ParseEntry:
    """Immutable, field-less base record.

    Each dataset parser is expected to customize it (typically by
    subclassing and adding its own fields).
    """
|
|
|
|
|
class DatasetParser(ABC, Generic[T]):
    """Interface shared by all dataset parsers.

    A concrete parser implements ``load``, ``parse`` and ``process_entry``;
    parsed entries of type ``T`` are kept in ``self._parsed_data``.
    """

    def __init__(self):
        # Starts empty; parse() is the method documented to fill it.
        self._parsed_data: list[T] = []

    @abstractmethod
    def load(self, **kwargs: Any) -> None:
        """Load the dataset (implementation-specific)."""

    @abstractmethod
    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
        """Parse the loaded dataset into self._parsed_data."""

    @property
    def get_parsed_data(self) -> list[T]:
        """Return the parsed entries.

        This is a property, so it is read as an attribute rather than called.

        Raises:
            ValueError: if ``_parsed_data`` is missing or empty.
        """
        entries = getattr(self, "_parsed_data", None)
        if entries:
            return entries
        raise ValueError("Parsed data has not been initialized.")

    @abstractmethod
    def process_entry(self, row: dict[str, Any]) -> T:
        """Build one entry of type ``T`` from a single row mapping."""
|
|
|
|
|
|
|
class HuggingFaceDatasetParser(DatasetParser[T]):
    """Base class for parsers that use datasets from Hugging Face.

    Subclasses provide the dataset identifier through the ``_data_source``
    class variable, or pass ``data_source`` explicitly to ``load``.
    """

    # Only annotated here, never assigned: a subclass must set it.
    _data_source: str

    def __init__(self):
        self.raw_data = None
        self.task_names: list[str] = []
        super().__init__()

    def get_task_names(self) -> list[str]:
        """Return the names recorded by the most recent ``load`` call."""
        return self.task_names

    @staticmethod
    @lru_cache(maxsize=3)
    def load_dataset_cached(
        data_source: str, config_name: str = "default", **kwargs: Any
    ):
        """Cached static method to load a dataset from Hugging Face.

        Results are memoized by ``lru_cache`` (at most 3 distinct argument
        combinations), so every positional and keyword argument must be
        hashable.

        Args:
            data_source: Dataset identifier, passed as the first argument of
                ``datasets.load_dataset``.
            config_name: Passed as the second positional argument of
                ``datasets.load_dataset``.
            **kwargs: Forwarded unchanged to ``datasets.load_dataset``.

        Returns:
            Whatever ``datasets.load_dataset`` returns for these arguments.
        """
        return datasets.load_dataset(data_source, config_name, **kwargs)

    def load(
        self,
        data_source: str | None = None,
        config_name: str = "all",
        trust_remote_code: bool = True,
        split: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Load the dataset using the Hugging Face datasets library.

        Stores the loaded object in ``self.raw_data`` and the available names
        in ``self.task_names``, then prints a one-line summary.

        Args:
            data_source: Dataset identifier; falls back to the
                ``_data_source`` class variable when omitted or empty.
            config_name: Dataset configuration name.
            trust_remote_code: Forwarded to ``datasets.load_dataset``.
                NOTE(review): the default of True permits running
                dataset-provided loading code; callers dealing with
                untrusted sources should pass False.
            split: Optional single split to load.
            **kwargs: Extra keyword arguments forwarded to the loader; they
                must be hashable because the loader is cached.

        Raises:
            ValueError: if no data source is given and the class does not
                define ``_data_source``.
        """
        # `_data_source` is only an annotation on the base class, so a plain
        # attribute read would raise AttributeError before the intended
        # ValueError below could be reached.
        data_source = data_source or getattr(self, "_data_source", None)
        if not data_source:
            raise ValueError("The 'data_source' class variable must be defined.")

        self.raw_data = self.load_dataset_cached(
            data_source,
            config_name=config_name,
            trust_remote_code=trust_remote_code,
            split=split,
            **kwargs,
        )

        # Without `split` the loader yields a dict-like object keyed by name.
        # With `split` it yields a single dataset that has no keys(), so the
        # requested split is the only name available.
        keys = getattr(self.raw_data, "keys", None)
        if callable(keys):
            self.task_names = list(keys())
        else:
            self.task_names = [split] if isinstance(split, str) else []

        print(
            f"Loaded dataset with {len(self.task_names)} tasks: {', '.join(self.task_names)}."
        )
|
|
|
|