JeffYang52415's picture
feat: update base_parser
2822485 unverified
raw
history blame
2.89 kB
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Generic, TypeVar
import datasets
# Generic type variable for the entry type handled by a parser. It is bound to
# ParseEntry via a string forward reference because the class is defined below.
T = TypeVar("T", bound="ParseEntry")
@dataclass(frozen=True)
class ParseEntry:
    """Immutable base record for parsed entries.

    Declares no fields of its own; each dataset parser customizes it by
    subclassing and adding the fields it needs.
    """
class DatasetParser(ABC, Generic[T]):
    """
    Common contract that every dataset parser has to fulfil.

    Concrete parsers implement ``load``, ``parse`` and ``process_entry``;
    the parsed entries are kept in ``self._parsed_data``.
    """

    def __init__(self):
        self._parsed_data: list[T] = []

    @abstractmethod
    def load(self, **kwargs: Any) -> None:
        """Load the underlying dataset (implemented by subclasses)."""

    @abstractmethod
    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
        """
        Turn the loaded dataset into entries stored in self._parsed_data.
        """

    @property
    def get_parsed_data(self) -> list[T]:
        """Return the parsed entries.

        Raises:
            ValueError: If ``_parsed_data`` is missing or empty.
        """
        # A missing attribute and an empty list are treated the same way.
        parsed = getattr(self, "_parsed_data", None)
        if not parsed:
            raise ValueError("Parsed data has not been initialized.")
        return parsed

    @abstractmethod
    def process_entry(self, row: dict[str, Any]) -> T:
        """Build one entry of type ``T`` from a raw row (implemented by subclasses)."""
# Base class for Hugging Face datasets
class HuggingFaceDatasetParser(DatasetParser[T]):
    """
    Base class for parsers that use datasets from Hugging Face.

    Subclasses are expected to assign ``_data_source``; ``load`` falls back
    to it whenever no explicit ``data_source`` argument is given.
    """

    # Class variable for the dataset name. Only annotated here (no value),
    # so it does not exist as an attribute until a subclass assigns it.
    _data_source: str

    def __init__(self):
        self.raw_data = None
        self.task_names: list[str] = []
        super().__init__()

    def get_task_names(self) -> list[str]:
        """Return the task names recorded by the most recent ``load`` call."""
        return self.task_names

    @staticmethod
    @lru_cache(maxsize=3)
    def load_dataset_cached(
        data_source: str, config_name: str = "default", **kwargs: Any
    ):
        """
        Cached static method to load a dataset from Hugging Face.

        Results are memoized for up to three distinct argument combinations.
        Because the arguments form the cache key, every value passed here
        (including those in ``kwargs``) must be hashable.
        """
        return datasets.load_dataset(data_source, config_name, **kwargs)

    def load(
        self,
        data_source: str | None = None,
        config_name: str = "all",
        trust_remote_code: bool = True,
        split: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Load the dataset using the Hugging Face datasets library.

        Args:
            data_source: Dataset name; defaults to the class-level
                ``_data_source`` when omitted or empty.
            config_name: Dataset configuration to load.
            trust_remote_code: Forwarded to ``datasets.load_dataset``.
            split: Optional single split to load. When given, the loader
                returns one dataset instead of a mapping of splits.
            **kwargs: Extra keyword arguments forwarded to the cached
                loader; they must be hashable.

        Raises:
            ValueError: If no data source is passed and the class does not
                define ``_data_source``.
        """
        # NOTE(review): trust_remote_code defaults to True, which per the
        # datasets documentation permits running dataset scripts from the
        # Hub locally -- confirm this default is intended.

        # Use class-level data_source if not provided. getattr is required
        # because ``_data_source`` is annotation-only: a plain attribute
        # access would raise AttributeError before the check below runs.
        data_source = data_source or getattr(self, "_data_source", None)
        if not data_source:
            raise ValueError("The 'data_source' class variable must be defined.")

        # Call the cached static method
        self.raw_data = self.load_dataset_cached(
            data_source,
            config_name=config_name,
            trust_remote_code=trust_remote_code,
            split=split,
            **kwargs,
        )

        if split is None:
            # No split requested: the result is a mapping keyed by name.
            self.task_names = list(self.raw_data.keys())
        else:
            # A single split was requested: the result is one dataset
            # without ``keys()``, so record the requested split name.
            self.task_names = [split]

        print(
            f"Loaded dataset with {len(self.task_names)} tasks: {', '.join(self.task_names)}."
        )
        # Additional common initialization can be added here