refactor: base parser interface
Browse files- llmdataparser/base_parser.py +189 -26
llmdataparser/base_parser.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
from dataclasses import dataclass
|
3 |
from functools import lru_cache
|
4 |
-
from typing import Any, Generic, TypeVar
|
5 |
|
6 |
import datasets
|
7 |
|
@@ -9,12 +9,17 @@ import datasets
|
|
9 |
T = TypeVar("T", bound="ParseEntry")
|
10 |
|
11 |
|
12 |
-
@dataclass(frozen=True)
|
13 |
class ParseEntry:
|
14 |
"""A simple base class for entries, customizable by each dataset parser."""
|
15 |
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
|
|
|
18 |
"""
|
19 |
Abstract base class defining the interface for all dataset parsers.
|
20 |
"""
|
@@ -39,40 +44,178 @@ class DatasetParser(ABC, Generic[T]):
|
|
39 |
return self._parsed_data
|
40 |
|
41 |
@abstractmethod
|
42 |
-
def process_entry(
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
-
# Base class for Hugging Face datasets
|
47 |
class HuggingFaceDatasetParser(DatasetParser[T]):
|
48 |
"""
|
49 |
Base class for parsers that use datasets from Hugging Face.
|
50 |
"""
|
51 |
|
52 |
-
_data_source
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
def __init__(self):
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
return self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
@staticmethod
|
63 |
@lru_cache(maxsize=3)
|
64 |
def load_dataset_cached(
|
65 |
-
data_source: str,
|
66 |
):
|
67 |
"""
|
68 |
Cached static method to load a dataset from Hugging Face.
|
69 |
"""
|
70 |
-
return datasets.load_dataset(data_source,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def load(
|
73 |
self,
|
74 |
-
|
75 |
-
config_name: str = "all",
|
76 |
trust_remote_code: bool = True,
|
77 |
split: str | None = None,
|
78 |
**kwargs: Any,
|
@@ -80,21 +223,41 @@ class HuggingFaceDatasetParser(DatasetParser[T]):
|
|
80 |
"""
|
81 |
Load the dataset using the Hugging Face datasets library.
|
82 |
"""
|
83 |
-
#
|
84 |
-
|
85 |
-
if not data_source:
|
86 |
-
raise ValueError("The 'data_source' class variable must be defined.")
|
87 |
|
88 |
# Call the cached static method
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
trust_remote_code=trust_remote_code,
|
93 |
split=split,
|
94 |
**kwargs,
|
95 |
)
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
print(
|
98 |
-
f"Loaded dataset with {len(self.
|
99 |
)
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
from dataclasses import dataclass
|
3 |
from functools import lru_cache
|
4 |
+
from typing import Any, ClassVar, Generic, TypeVar
|
5 |
|
6 |
import datasets
|
7 |
|
|
|
9 |
T = TypeVar("T", bound="ParseEntry")
|
10 |
|
11 |
|
12 |
+
@dataclass(frozen=True, kw_only=True, slots=True)
|
13 |
class ParseEntry:
|
14 |
"""A simple base class for entries, customizable by each dataset parser."""
|
15 |
|
16 |
+
prompt: str
|
17 |
+
answer: str
|
18 |
+
raw_question: str
|
19 |
+
raw_answer: str
|
20 |
|
21 |
+
|
22 |
+
class DatasetParser(Generic[T], ABC):
|
23 |
"""
|
24 |
Abstract base class defining the interface for all dataset parsers.
|
25 |
"""
|
|
|
44 |
return self._parsed_data
|
45 |
|
46 |
@abstractmethod
def process_entry(
    self,
    row: dict[str, Any],
    task_name: str | None = None,
    **kwargs: Any,
) -> T:
    """
    Convert one raw dataset row into a parsed entry.

    Concrete parsers must override this method; the base implementation
    has no behaviour of its own.

    Args:
        row: Mapping holding the fields of a single dataset record.
        task_name: Name of the task the record belongs to, if any.
        **kwargs: Extra, parser-specific options.

    Returns:
        T: The parsed entry (an instance of a ``ParseEntry`` subclass).
    """
    ...
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass(frozen=True, kw_only=True, slots=True)
class HuggingFaceParseEntry(ParseEntry):
    """A ``ParseEntry`` that additionally records which task it came from."""

    # Name of the task this entry is associated with.
    task_name: str
|
68 |
|
69 |
|
|
|
70 |
class HuggingFaceDatasetParser(DatasetParser[T]):
|
71 |
"""
|
72 |
Base class for parsers that use datasets from Hugging Face.
|
73 |
"""
|
74 |
|
75 |
+
# _data_source is the name of the dataset, e.g. "lighteval/MATH"
|
76 |
+
_data_source: ClassVar[str]
|
77 |
+
# _task_names is the list of tasks in the dataset, e.g. ["algebra", "geometry", "statistics"]
|
78 |
+
_task_names: ClassVar[list[str]]
|
79 |
+
# _default_task is the default task to use if no task is specified, e.g. "algebra"
|
80 |
+
_default_task: ClassVar[str]
|
81 |
+
# _default_system_prompt is the default system prompt to use if no system prompt is specified
|
82 |
+
_default_system_prompt: ClassVar[str]
|
83 |
|
84 |
+
def __init__(self, system_prompt: str | None = None, **kwargs: Any) -> None:
    """
    Initialize a HuggingFaceDatasetParser.

    Args:
        system_prompt: Optional custom system prompt to use instead of the default.
            If not provided (or empty), the class's _default_system_prompt is used.
        **kwargs: Accepted for forward compatibility with subclasses. They are
            NOT forwarded to the parent class: the parent initializer is called
            without arguments.
    """
    super().__init__()
    # raw_data is the dataset loaded from HuggingFace (None until load() is called)
    self.raw_data: dict[str, Any] | None = None
    # split_names is the list of splits in the dataset, e.g. ["train", "test", "validation"]
    self.split_names: list[str] = []
    # _current_task is the task currently being processed, e.g. "algebra"
    self._current_task: str = ""
    # _system_prompt is the system prompt currently being used; any falsy
    # value (None or "") falls back to the class-level default
    self._system_prompt: str = system_prompt or self._default_system_prompt
|
102 |
+
|
103 |
+
def _get_current_task(self, data_entry: dict[str, Any] | None = None) -> str:
|
104 |
+
"""
|
105 |
+
Get the currently loaded task name.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
data_entry: Optional dictionary containing entry data that might include task information
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
str: The task name from either the data entry (if available) or the currently set task
|
112 |
+
"""
|
113 |
+
# If data_entry is provided and contains task information, use it
|
114 |
+
if data_entry is not None and hasattr(self, "_get_task_from_entry"):
|
115 |
+
try:
|
116 |
+
return self._get_task_from_entry(data_entry)
|
117 |
+
except (KeyError, AttributeError):
|
118 |
+
pass
|
119 |
|
120 |
+
# Otherwise return the task set during load()
|
121 |
+
return self._current_task or self._default_task
|
122 |
+
|
123 |
+
@property
def task_names(self) -> list[str]:
    """Return the task names declared for this parser (``_task_names``)."""
    names = self._task_names
    return names
|
127 |
+
|
128 |
+
@property
def total_tasks(self) -> int:
    """Return how many task names are declared for this parser."""
    declared_tasks = self._task_names
    return len(declared_tasks)
|
132 |
+
|
133 |
+
@property
def get_huggingface_link(self) -> str:
    """Return the Hugging Face dataset page URL built from ``_data_source``."""
    base_url = "https://huggingface.co/datasets/"
    # str.join keeps the original's strictness: a non-str _data_source
    # still raises TypeError instead of being silently stringified.
    return "".join((base_url, self._data_source))
|
136 |
|
137 |
@staticmethod
@lru_cache(maxsize=3)
def load_dataset_cached(
    data_source: str,
    task_name: str = "default",
    **kwargs: Any,
):
    """
    Load a Hugging Face dataset, memoising the three most recent calls.

    Args:
        data_source: Dataset identifier passed to ``datasets.load_dataset``.
        task_name: Second positional argument passed to
            ``datasets.load_dataset``.
        **kwargs: Forwarded unchanged to ``datasets.load_dataset``. Because
            the call is wrapped in ``lru_cache``, every argument must be
            hashable.

    Returns:
        Whatever ``datasets.load_dataset`` returns for these arguments.
    """
    loaded = datasets.load_dataset(data_source, task_name, **kwargs)
    return loaded
|
146 |
+
|
147 |
+
def parse(
|
148 |
+
self,
|
149 |
+
split_names: str | list[str] | None = None,
|
150 |
+
force: bool = False,
|
151 |
+
**kwargs: Any,
|
152 |
+
) -> None:
|
153 |
+
"""
|
154 |
+
Parse the MATH dataset splits into structured entries.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
split_names: Dataset splits to parse. Can be:
|
158 |
+
- None: Parse all available splits
|
159 |
+
- str: Parse a single split (e.g., "train")
|
160 |
+
- list[str]: Parse multiple splits (e.g., ["train", "test"])
|
161 |
+
force: If True, overwrites existing parsed data without confirmation.
|
162 |
+
If False and parsed data exists, prompts for confirmation.
|
163 |
+
**kwargs: Additional keyword arguments passed to process_entry
|
164 |
+
|
165 |
+
Raises:
|
166 |
+
ValueError: If no data is loaded or if a specified split name doesn't exist
|
167 |
+
"""
|
168 |
+
if self.raw_data is None:
|
169 |
+
raise ValueError("No data loaded. Please load the dataset first.")
|
170 |
+
|
171 |
+
if self._parsed_data and not force:
|
172 |
+
response = input(
|
173 |
+
f"Found {len(self._parsed_data)} existing parsed entries. "
|
174 |
+
"Do you want to overwrite them? [y/N]: "
|
175 |
+
).lower()
|
176 |
+
if response not in ("y", "yes"):
|
177 |
+
print("Parsing cancelled. Existing data preserved.")
|
178 |
+
return
|
179 |
+
|
180 |
+
self._parsed_data.clear()
|
181 |
+
|
182 |
+
# Dataset with splits
|
183 |
+
if split_names is None:
|
184 |
+
split_names = self.split_names
|
185 |
+
elif isinstance(split_names, str):
|
186 |
+
split_names = [split_names]
|
187 |
+
|
188 |
+
for split_name in split_names:
|
189 |
+
if split_name not in self.split_names:
|
190 |
+
raise ValueError(f"Split '{split_name}' not found in the dataset.")
|
191 |
+
|
192 |
+
dataset_split = self.raw_data[split_name]
|
193 |
+
total_entries = len(dataset_split)
|
194 |
+
print(f"Processing {split_name} split with {total_entries} entries...")
|
195 |
+
|
196 |
+
for index, entry in enumerate(dataset_split, start=1):
|
197 |
+
try:
|
198 |
+
task_name = self._get_current_task(data_entry=entry)
|
199 |
+
parsed_entry = self.process_entry(entry, task_name, **kwargs)
|
200 |
+
self._parsed_data.append(parsed_entry)
|
201 |
+
|
202 |
+
# Print progress every 100 entries
|
203 |
+
if index % 100 == 0:
|
204 |
+
print(
|
205 |
+
f"Processed {index}/{total_entries} entries from '{split_name}'"
|
206 |
+
)
|
207 |
+
|
208 |
+
except Exception as e:
|
209 |
+
print(f"Error processing entry {index} in {split_name}: {str(e)}")
|
210 |
+
continue
|
211 |
+
|
212 |
+
print(f"Completed parsing {index} entries from '{split_name}'")
|
213 |
+
|
214 |
+
print(f"Total parsed entries: {len(self._parsed_data)}")
|
215 |
|
216 |
def load(
|
217 |
self,
|
218 |
+
task_name: str | None = None,
|
|
|
219 |
trust_remote_code: bool = True,
|
220 |
split: str | None = None,
|
221 |
**kwargs: Any,
|
|
|
223 |
"""
|
224 |
Load the dataset using the Hugging Face datasets library.
|
225 |
"""
|
226 |
+
# Set the task name
|
227 |
+
self._current_task = task_name or self._default_task
|
|
|
|
|
228 |
|
229 |
# Call the cached static method
|
230 |
+
raw_data = self.load_dataset_cached(
|
231 |
+
self._data_source,
|
232 |
+
task_name=self._current_task,
|
233 |
trust_remote_code=trust_remote_code,
|
234 |
split=split,
|
235 |
**kwargs,
|
236 |
)
|
237 |
+
|
238 |
+
# Handle split-specific loading
|
239 |
+
if split:
|
240 |
+
self.raw_data = {split: raw_data}
|
241 |
+
self.split_names = [split]
|
242 |
+
else:
|
243 |
+
self.raw_data = raw_data
|
244 |
+
self.split_names = list(raw_data.keys())
|
245 |
+
|
246 |
print(
|
247 |
+
f"Loaded dataset with {len(self.split_names)} groups: {', '.join(self.split_names)}."
|
248 |
)
|
249 |
+
|
250 |
+
def __repr__(self) -> str:
|
251 |
+
status = "loaded" if self.raw_data is not None else "not loaded"
|
252 |
+
parsed_count = len(self._parsed_data) if self._parsed_data else 0
|
253 |
+
return (
|
254 |
+
f"{self.__class__.__name__}("
|
255 |
+
f"data_source='{self._data_source}', "
|
256 |
+
f"task='{self._current_task}', "
|
257 |
+
f"status='{status}', "
|
258 |
+
f"parsed_entries={parsed_count}"
|
259 |
+
")"
|
260 |
+
)
|
261 |
+
|
262 |
+
def __str__(self) -> str:
|
263 |
+
return self.__repr__()
|