JeffYang52415 commited on
Commit
2822485
·
unverified ·
1 Parent(s): f835380

feat: update base_parser

Browse files
.pre-commit-config.yaml CHANGED
@@ -12,6 +12,7 @@ repos:
12
  hooks:
13
  - id: flake8
14
  additional_dependencies: ["typing-extensions>=4.8.0"]
 
15
  - repo: https://github.com/PyCQA/isort
16
  rev: 5.12.0
17
  hooks:
@@ -21,7 +22,13 @@ repos:
21
  rev: v1.5.1
22
  hooks:
23
  - id: mypy
24
- args: ["--python-version=3.11", "--install-types", "--non-interactive"]
 
 
 
 
 
 
25
  additional_dependencies:
26
  - "typing-extensions>=4.8.0"
27
  - repo: https://github.com/pre-commit/pre-commit-hooks
 
12
  hooks:
13
  - id: flake8
14
  additional_dependencies: ["typing-extensions>=4.8.0"]
15
+ args: ["--ignore=E203,E501,W503"]
16
  - repo: https://github.com/PyCQA/isort
17
  rev: 5.12.0
18
  hooks:
 
22
  rev: v1.5.1
23
  hooks:
24
  - id: mypy
25
+ args:
26
+ [
27
+ "--python-version=3.11",
28
+ "--install-types",
29
+ "--non-interactive",
30
+ "--ignore-missing-imports",
31
+ ]
32
  additional_dependencies:
33
  - "typing-extensions>=4.8.0"
34
  - repo: https://github.com/pre-commit/pre-commit-hooks
llmdataparser/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llmdataparser/__init__.py
2
+ from typing import Type
3
+
4
+ from .base_parser import DatasetParser
5
+ from .mmlu_parser import MMLUDatasetParser
6
+
7
+
8
class ParserRegistry:
    """
    Registry to keep track of available parsers and provide them on request.

    Parser names are case-insensitive: they are lowercased on registration
    and on lookup.
    """

    # Maps lowercase parser name -> parser class.
    _registry: "dict[str, Type[DatasetParser]]" = {}

    @classmethod
    def register_parser(cls, name: str, parser_class: "Type[DatasetParser]") -> None:
        """Register *parser_class* under *name* (lowercased)."""
        cls._registry[name.lower()] = parser_class

    @classmethod
    def get_parser(cls, name: str, **kwargs) -> "DatasetParser":
        """
        Instantiate and return the parser registered under *name*.

        Note: this returns an *instance* (``parser_class(**kwargs)``), so the
        return annotation is the instance type, not ``Type[...]``.

        Raises:
            ValueError: If no parser is registered under *name*.
        """
        parser_class = cls._registry.get(name.lower())
        if parser_class is None:
            raise ValueError(f"Parser '{name}' is not registered.")
        return parser_class(**kwargs)

    @classmethod
    def list_parsers(cls) -> list[str]:
        """Returns a list of available parser names."""
        return list(cls._registry.keys())
30
+
31
+
32
# Register built-in parsers so they are retrievable by name via
# ParserRegistry.get_parser at import time.
ParserRegistry.register_parser("mmlu", MMLUDatasetParser)
llmdataparser/base_parser.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from functools import lru_cache
4
+ from typing import Any, Generic, TypeVar
5
+
6
+ import datasets
7
+
8
# Generic type variable bound to ParseEntry, so each DatasetParser subclass
# can declare the concrete entry type it produces.
T = TypeVar("T", bound="ParseEntry")
10
+
11
+
12
@dataclass(frozen=True)
class ParseEntry:
    """Immutable base record for parsed entries; dataset parsers subclass it to add fields."""
15
+
16
+
17
class DatasetParser(ABC, Generic[T]):
    """
    Abstract base class defining the interface for all dataset parsers.

    Subclasses implement load(), parse(), and process_entry(); parsed entries
    of type T accumulate in self._parsed_data.
    """

    def __init__(self) -> None:
        # Entries produced by parse(); empty until parse() has run.
        self._parsed_data: list[T] = []

    @abstractmethod
    def load(self, **kwargs: Any) -> None:
        """Load the raw dataset; implementations define the source and options."""
        pass

    @abstractmethod
    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
        """
        Parse the loaded dataset into self._parsed_data.
        """

    @property
    def get_parsed_data(self) -> list[T]:
        """
        The parsed entries.

        Raises ValueError when parse() has not yet populated any data.
        NOTE(review): getter-style name on a property is unconventional, but
        renaming would break callers.
        """
        if not hasattr(self, "_parsed_data") or not self._parsed_data:
            raise ValueError("Parsed data has not been initialized.")
        return self._parsed_data

    @abstractmethod
    def process_entry(self, row: dict[str, Any]) -> T:
        """Convert a single raw dataset row into a parsed entry of type T."""
        pass
44
+
45
+
46
# Base class for Hugging Face datasets
class HuggingFaceDatasetParser(DatasetParser[T]):
    """
    Base class for parsers that use datasets from Hugging Face.

    Subclasses should set the ``_data_source`` class variable to the dataset
    name on the Hugging Face hub.
    """

    _data_source: str  # Class variable for the dataset name

    def __init__(self) -> None:
        # Raw object returned by datasets.load_dataset: a mapping of split
        # name -> Dataset, or a single Dataset when `split` is requested.
        self.raw_data = None
        # Task/split names discovered by the last load().
        self.task_names: list[str] = []
        super().__init__()

    def get_task_names(self) -> list[str]:
        """Return the task/split names discovered by the last load()."""
        return self.task_names

    @staticmethod
    @lru_cache(maxsize=3)
    def load_dataset_cached(
        data_source: str, config_name: str = "default", **kwargs: Any
    ):
        """
        Cached static method to load a dataset from Hugging Face.

        Static (not a bound method) so lru_cache does not key on — and keep
        alive — parser instances.
        """
        return datasets.load_dataset(data_source, config_name, **kwargs)

    def load(
        self,
        data_source: str | None = None,
        config_name: str = "all",
        trust_remote_code: bool = True,
        split: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Load the dataset using the Hugging Face datasets library.

        Args:
            data_source: Hub dataset name; defaults to the class-level
                ``_data_source``.
            config_name: Dataset configuration name to load.
            trust_remote_code: Forwarded to ``datasets.load_dataset``.
                SECURITY NOTE(review): defaults to True, which permits
                executing code shipped with the dataset repository; pass
                False when loading untrusted datasets.
            split: Optional split name; when given, only that split is loaded.
            **kwargs: Extra arguments forwarded to ``datasets.load_dataset``.

        Raises:
            ValueError: If no data source is given and ``_data_source`` is
                not defined on the class.
        """
        # Use class-level data_source if not provided. getattr (rather than
        # direct attribute access) ensures a subclass that forgot to define
        # _data_source gets the intended ValueError, not an AttributeError.
        data_source = data_source or getattr(self, "_data_source", None)
        if not data_source:
            raise ValueError("The 'data_source' class variable must be defined.")

        # Call the cached static method
        self.raw_data = self.load_dataset_cached(
            data_source,
            config_name=config_name,
            trust_remote_code=trust_remote_code,
            split=split,
            **kwargs,
        )
        # With no `split`, load_dataset returns a DatasetDict whose keys are
        # the split/task names; with `split` set it returns a single Dataset
        # (no .keys()), so fall back to the requested split name.
        if hasattr(self.raw_data, "keys"):
            self.task_names = list(self.raw_data.keys())
        else:
            self.task_names = [split] if split else []
        print(
            f"Loaded dataset with {len(self.task_names)} tasks: {', '.join(self.task_names)}."
        )
        # Additional common initialization can be added here
pyproject.toml CHANGED
@@ -49,11 +49,13 @@ profile = "black"
49
  line_length = 88
50
  known_first_party = ["llmdataparser"]
51
 
 
52
  [tool.flake8]
53
- max-line-length = 88
54
- ignore = [
55
- "E501" # Line too long
56
  ]
 
57
 
58
  [tool.ruff]
59
  line-length = 88
 
49
  line_length = 88
50
  known_first_party = ["llmdataparser"]
51
 
52
+ # .flake8
53
  [tool.flake8]
54
+ ignore = ["E231", "E241", "E501"]
55
+ per-file-ignores = [
56
+ '__init__.py:F401',
57
  ]
58
+ count = true
59
 
60
  [tool.ruff]
61
  line-length = 88