feat: update base_parser
Browse files- .pre-commit-config.yaml +8 -1
- llmdataparser/__init__.py +33 -0
- llmdataparser/base_parser.py +100 -0
- pyproject.toml +5 -3
.pre-commit-config.yaml
CHANGED
@@ -12,6 +12,7 @@ repos:
|
|
12 |
hooks:
|
13 |
- id: flake8
|
14 |
additional_dependencies: ["typing-extensions>=4.8.0"]
|
|
|
15 |
- repo: https://github.com/PyCQA/isort
|
16 |
rev: 5.12.0
|
17 |
hooks:
|
@@ -21,7 +22,13 @@ repos:
|
|
21 |
rev: v1.5.1
|
22 |
hooks:
|
23 |
- id: mypy
|
24 |
-
args:
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
additional_dependencies:
|
26 |
- "typing-extensions>=4.8.0"
|
27 |
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
|
12 |
hooks:
|
13 |
- id: flake8
|
14 |
additional_dependencies: ["typing-extensions>=4.8.0"]
|
15 |
+
args: ["--ignore=E203,E501,W503"]
|
16 |
- repo: https://github.com/PyCQA/isort
|
17 |
rev: 5.12.0
|
18 |
hooks:
|
|
|
22 |
rev: v1.5.1
|
23 |
hooks:
|
24 |
- id: mypy
|
25 |
+
args:
|
26 |
+
[
|
27 |
+
"--python-version=3.11",
|
28 |
+
"--install-types",
|
29 |
+
"--non-interactive",
|
30 |
+
"--ignore-missing-imports",
|
31 |
+
]
|
32 |
additional_dependencies:
|
33 |
- "typing-extensions>=4.8.0"
|
34 |
- repo: https://github.com/pre-commit/pre-commit-hooks
|
llmdataparser/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# llmdataparser/__init__.py
|
2 |
+
from typing import Type
|
3 |
+
|
4 |
+
from .base_parser import DatasetParser
|
5 |
+
from .mmlu_parser import MMLUDatasetParser
|
6 |
+
|
7 |
+
|
8 |
+
class ParserRegistry:
    """
    Registry to keep track of available parsers and provide them on request.
    """

    # Maps lowercase parser name -> parser class.
    _registry: dict = {}

    @classmethod
    def register_parser(cls, name: str, parser_class: Type[DatasetParser]) -> None:
        """Register *parser_class* under *name* (stored case-insensitively)."""
        cls._registry[name.lower()] = parser_class

    @classmethod
    def get_parser(cls, name: str, **kwargs) -> DatasetParser:
        """Instantiate and return the parser registered under *name*.

        Note: returns an *instance* (``parser_class(**kwargs)``), so the
        return annotation is ``DatasetParser``, not ``Type[DatasetParser]``.

        Raises:
            ValueError: if no parser is registered under *name*.
        """
        parser_class = cls._registry.get(name.lower())
        if parser_class is None:
            raise ValueError(f"Parser '{name}' is not registered.")
        return parser_class(**kwargs)

    @classmethod
    def list_parsers(cls) -> list[str]:
        """Returns a list of available parser names."""
        return list(cls._registry.keys())
|
30 |
+
|
31 |
+
|
32 |
+
# Register parsers
|
33 |
+
ParserRegistry.register_parser("mmlu", MMLUDatasetParser)
|
llmdataparser/base_parser.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from functools import lru_cache
|
4 |
+
from typing import Any, Generic, TypeVar
|
5 |
+
|
6 |
+
import datasets
|
7 |
+
|
8 |
+
# Define the generic type variable
|
9 |
+
T = TypeVar("T", bound="ParseEntry")
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass(frozen=True)
class ParseEntry:
    """Immutable base record; each dataset parser subclasses this with its own fields."""
|
15 |
+
|
16 |
+
|
17 |
+
class DatasetParser(ABC, Generic[T]):
    """
    Abstract base class defining the interface for all dataset parsers.
    """

    def __init__(self):
        # Filled in by parse(); starts out empty.
        self._parsed_data: list[T] = []

    @abstractmethod
    def load(self, **kwargs: Any) -> None:
        """Load the raw dataset; concrete parsers define the source."""

    @abstractmethod
    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
        """
        Parse the loaded dataset into self._parsed_data.
        """

    @property
    def get_parsed_data(self) -> list[T]:
        """The parsed entries; raises if parse() has not populated them yet."""
        data = getattr(self, "_parsed_data", None)
        if not data:
            raise ValueError("Parsed data has not been initialized.")
        return data

    @abstractmethod
    def process_entry(self, row: dict[str, Any]) -> T:
        """Convert a single raw dataset row into a typed entry."""
|
44 |
+
|
45 |
+
|
46 |
+
# Base class for Hugging Face datasets
class HuggingFaceDatasetParser(DatasetParser[T]):
    """
    Base class for parsers that use datasets from Hugging Face.
    """

    _data_source: str  # Class variable for the dataset name

    def __init__(self):
        # Raw dataset as returned by datasets.load_dataset(): a DatasetDict
        # when no split is requested, a single Dataset when one is.
        self.raw_data = None
        self.task_names = []
        super().__init__()

    def get_task_names(self) -> list[str]:
        """Return the task/split names discovered by load()."""
        return self.task_names

    @staticmethod
    @lru_cache(maxsize=3)
    def load_dataset_cached(
        data_source: str, config_name: str = "default", **kwargs: Any
    ):
        """
        Cached static method to load a dataset from Hugging Face.

        NOTE: every value in **kwargs must be hashable, or lru_cache raises.
        """
        return datasets.load_dataset(data_source, config_name, **kwargs)

    def load(
        self,
        data_source: str | None = None,
        config_name: str = "all",
        trust_remote_code: bool = True,
        split: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Load the dataset using the Hugging Face datasets library.

        Args:
            data_source: Hugging Face dataset name; falls back to the
                class-level ``_data_source`` when omitted.
            config_name: Dataset configuration to load.
            trust_remote_code: Passed through to ``datasets.load_dataset``.
                SECURITY: defaults to True, which lets the hub dataset run
                arbitrary code locally — only load datasets you trust.
            split: Optional single split name. When given, load_dataset
                returns a single ``Dataset`` instead of a ``DatasetDict``.

        Raises:
            ValueError: if no data source is available.
        """
        # Use class-level data_source if not provided. getattr (not direct
        # attribute access) so a missing class variable raises the intended
        # ValueError below instead of AttributeError.
        data_source = data_source or getattr(self, "_data_source", None)
        if not data_source:
            raise ValueError("The 'data_source' class variable must be defined.")

        # Call the cached static method
        self.raw_data = self.load_dataset_cached(
            data_source,
            config_name=config_name,
            trust_remote_code=trust_remote_code,
            split=split,
            **kwargs,
        )
        # BUG FIX: with a specific split, raw_data is a Dataset and has no
        # .keys(); the only "task" is the requested split itself.
        if split is not None:
            self.task_names = [split]
        else:
            self.task_names = list(self.raw_data.keys())
        print(
            f"Loaded dataset with {len(self.task_names)} tasks: {', '.join(self.task_names)}."
        )
        # Additional common initialization can be added here
|
pyproject.toml
CHANGED
@@ -49,11 +49,13 @@ profile = "black"
|
|
49 |
line_length = 88
|
50 |
known_first_party = ["llmdataparser"]
|
51 |
|
|
|
52 |
[tool.flake8]
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
]
|
|
|
57 |
|
58 |
[tool.ruff]
|
59 |
line-length = 88
|
|
|
49 |
line_length = 88
|
50 |
known_first_party = ["llmdataparser"]
|
51 |
|
52 |
+
# .flake8
|
53 |
[tool.flake8]
|
54 |
+
ignore = ['E231', 'E241', "E501"]
|
55 |
+
per-file-ignores = [
|
56 |
+
'__init__.py:F401',
|
57 |
]
|
58 |
+
count = true
|
59 |
|
60 |
[tool.ruff]
|
61 |
line-length = 88
|