JeffYang52415 committed on
Commit
9682764
·
unverified ·
1 Parent(s): 9bc0c66

refactor: base parser interface

Browse files
Files changed (1) hide show
  1. llmdataparser/base_parser.py +189 -26
llmdataparser/base_parser.py CHANGED
@@ -1,7 +1,7 @@
1
  from abc import ABC, abstractmethod
2
  from dataclasses import dataclass
3
  from functools import lru_cache
4
- from typing import Any, Generic, TypeVar
5
 
6
  import datasets
7
 
@@ -9,12 +9,17 @@ import datasets
9
  T = TypeVar("T", bound="ParseEntry")
10
 
11
 
12
- @dataclass(frozen=True)
13
  class ParseEntry:
14
  """A simple base class for entries, customizable by each dataset parser."""
15
 
 
 
 
 
16
 
17
- class DatasetParser(ABC, Generic[T]):
 
18
  """
19
  Abstract base class defining the interface for all dataset parsers.
20
  """
@@ -39,40 +44,178 @@ class DatasetParser(ABC, Generic[T]):
39
  return self._parsed_data
40
 
41
  @abstractmethod
42
- def process_entry(self, row: dict[str, Any]) -> T:
43
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
- # Base class for Hugging Face datasets
47
  class HuggingFaceDatasetParser(DatasetParser[T]):
48
  """
49
  Base class for parsers that use datasets from Hugging Face.
50
  """
51
 
52
- _data_source: str # Class variable for the dataset name
 
 
 
 
 
 
 
53
 
54
- def __init__(self):
55
- self.raw_data = None
56
- self.task_names = []
 
 
 
 
 
 
57
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- def get_task_names(self) -> list[str]:
60
- return self.task_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  @staticmethod
63
  @lru_cache(maxsize=3)
64
  def load_dataset_cached(
65
- data_source: str, config_name: str = "default", **kwargs: Any
66
  ):
67
  """
68
  Cached static method to load a dataset from Hugging Face.
69
  """
70
- return datasets.load_dataset(data_source, config_name, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def load(
73
  self,
74
- data_source: str | None = None,
75
- config_name: str = "all",
76
  trust_remote_code: bool = True,
77
  split: str | None = None,
78
  **kwargs: Any,
@@ -80,21 +223,41 @@ class HuggingFaceDatasetParser(DatasetParser[T]):
80
  """
81
  Load the dataset using the Hugging Face datasets library.
82
  """
83
- # Use class-level data_source if not provided
84
- data_source = data_source or self._data_source
85
- if not data_source:
86
- raise ValueError("The 'data_source' class variable must be defined.")
87
 
88
  # Call the cached static method
89
- self.raw_data = self.load_dataset_cached(
90
- data_source,
91
- config_name=config_name,
92
  trust_remote_code=trust_remote_code,
93
  split=split,
94
  **kwargs,
95
  )
96
- self.task_names = list(self.raw_data.keys())
 
 
 
 
 
 
 
 
97
  print(
98
- f"Loaded dataset with {len(self.task_names)} tasks: {', '.join(self.task_names)}."
99
  )
100
- # Additional common initialization can be added here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from abc import ABC, abstractmethod
2
  from dataclasses import dataclass
3
  from functools import lru_cache
4
+ from typing import Any, ClassVar, Generic, TypeVar
5
 
6
  import datasets
7
 
 
9
  T = TypeVar("T", bound="ParseEntry")
10
 
11
 
12
@dataclass(frozen=True, kw_only=True, slots=True)
class ParseEntry:
    """A simple base class for entries, customizable by each dataset parser."""

    # Final prompt text for this entry — presumably includes the system prompt
    # and question formatting done by the subclass parser; confirm there.
    prompt: str
    # Normalized answer string produced by the subclass parser.
    answer: str
    # Question text as it appears in the source dataset, unmodified.
    raw_question: str
    # Answer text as it appears in the source dataset, unmodified.
    raw_answer: str
20
 
21
+
22
+ class DatasetParser(Generic[T], ABC):
23
  """
24
  Abstract base class defining the interface for all dataset parsers.
25
  """
 
44
  return self._parsed_data
45
 
46
    @abstractmethod
    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> T:
        """
        Process a single entry from the dataset.

        Subclasses implement this to convert one raw dataset row into a
        structured entry object.

        Args:
            row: A dictionary representing a single entry from the dataset.
            task_name: Optional task name for the entry; when None the
                implementation is expected to fall back to its own default.
            **kwargs: Additional keyword arguments forwarded by the caller.

        Returns:
            T: The processed entry, typically an instance of a subclass of ParseEntry.
        """
61
+
62
+
63
@dataclass(frozen=True, kw_only=True, slots=True)
class HuggingFaceParseEntry(ParseEntry):
    """ParseEntry with an additional task_name field."""

    # Name of the dataset task/config this entry came from (e.g. "algebra").
    task_name: str
68
 
69
 
 
70
class HuggingFaceDatasetParser(DatasetParser[T]):
    """
    Base class for parsers that use datasets from Hugging Face.

    Subclasses must define the ClassVar fields below and implement
    ``process_entry`` to turn a raw dataset row into a parsed entry.
    """

    # _data_source is the name of the dataset, e.g. "lighteval/MATH"
    _data_source: ClassVar[str]
    # _task_names is the list of tasks in the dataset, e.g. ["algebra", "geometry", "statistics"]
    _task_names: ClassVar[list[str]]
    # _default_task is the default task to use if no task is specified, e.g. "algebra"
    _default_task: ClassVar[str]
    # _default_system_prompt is the default system prompt to use if no system prompt is specified
    _default_system_prompt: ClassVar[str]

    def __init__(self, system_prompt: str | None = None, **kwargs):
        """
        Initialize a HuggingFaceDatasetParser.

        Args:
            system_prompt: Optional custom system prompt to use instead of the default.
                If not provided, will use the class's _default_system_prompt.
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        super().__init__()
        # raw_data is the dataset loaded from HuggingFace (split name -> split data)
        self.raw_data: dict[str, Any] | None = None
        # split_names is the list of splits in the dataset, e.g. ["train", "test", "validation"]
        self.split_names: list[str] = []
        # _current_task is the task currently being processed, e.g. "algebra"
        self._current_task: str = ""
        # _system_prompt is the system prompt currently being used
        self._system_prompt: str = system_prompt or self._default_system_prompt

    def _get_current_task(self, data_entry: dict[str, Any] | None = None) -> str:
        """
        Get the currently loaded task name.

        Args:
            data_entry: Optional dictionary containing entry data that might include
                task information.

        Returns:
            str: The task name from either the data entry (if available) or the
                currently set task, falling back to the class default.
        """
        # If the subclass knows how to derive the task from an entry, prefer that.
        if data_entry is not None and hasattr(self, "_get_task_from_entry"):
            try:
                return self._get_task_from_entry(data_entry)
            except (KeyError, AttributeError):
                # Derivation failed; fall through to the task set during load().
                pass

        return self._current_task or self._default_task

    @property
    def task_names(self) -> list[str]:
        """Get all available task names."""
        return self._task_names

    @property
    def total_tasks(self) -> int:
        """Get total number of available tasks."""
        return len(self._task_names)

    @property
    def get_huggingface_link(self) -> str:
        """URL of the dataset's page on the Hugging Face Hub."""
        return "https://huggingface.co/datasets/" + self._data_source

    @staticmethod
    @lru_cache(maxsize=3)
    def load_dataset_cached(
        data_source: str, task_name: str = "default", **kwargs: Any
    ):
        """
        Cached static method to load a dataset from Hugging Face.

        NOTE: all keyword arguments must be hashable for lru_cache to work.
        """
        return datasets.load_dataset(data_source, task_name, **kwargs)

    def parse(
        self,
        split_names: str | list[str] | None = None,
        force: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Parse the loaded dataset splits into structured entries.

        Args:
            split_names: Dataset splits to parse. Can be:
                - None: Parse all available splits
                - str: Parse a single split (e.g., "train")
                - list[str]: Parse multiple splits (e.g., ["train", "test"])
            force: If True, overwrites existing parsed data without confirmation.
                If False and parsed data exists, prompts for confirmation.
            **kwargs: Additional keyword arguments passed to process_entry

        Raises:
            ValueError: If no data is loaded or if a specified split name doesn't exist
        """
        if self.raw_data is None:
            raise ValueError("No data loaded. Please load the dataset first.")

        if self._parsed_data and not force:
            response = input(
                f"Found {len(self._parsed_data)} existing parsed entries. "
                "Do you want to overwrite them? [y/N]: "
            ).lower()
            if response not in ("y", "yes"):
                print("Parsing cancelled. Existing data preserved.")
                return

        self._parsed_data.clear()

        # Normalize split_names to a list of split names.
        if split_names is None:
            split_names = self.split_names
        elif isinstance(split_names, str):
            split_names = [split_names]

        for split_name in split_names:
            if split_name not in self.split_names:
                raise ValueError(f"Split '{split_name}' not found in the dataset.")

            dataset_split = self.raw_data[split_name]
            total_entries = len(dataset_split)
            print(f"Processing {split_name} split with {total_entries} entries...")

            # Bind index up front so the summary below is well-defined even for
            # an empty split (previously `index` could be unbound -> NameError).
            index = 0
            for index, entry in enumerate(dataset_split, start=1):
                try:
                    task_name = self._get_current_task(data_entry=entry)
                    parsed_entry = self.process_entry(entry, task_name, **kwargs)
                    self._parsed_data.append(parsed_entry)

                    # Print progress every 100 entries
                    if index % 100 == 0:
                        print(
                            f"Processed {index}/{total_entries} entries from '{split_name}'"
                        )

                except Exception as e:
                    # Best-effort parsing: report the failure and skip the entry.
                    print(f"Error processing entry {index} in {split_name}: {str(e)}")
                    continue

            print(f"Completed parsing {index} entries from '{split_name}'")

        print(f"Total parsed entries: {len(self._parsed_data)}")

    def load(
        self,
        task_name: str | None = None,
        trust_remote_code: bool = True,
        split: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Load the dataset using the Hugging Face datasets library.

        Args:
            task_name: Dataset config/task to load; defaults to _default_task.
            trust_remote_code: Passed through to datasets.load_dataset.
            split: Optional single split to load; when given, raw_data is keyed
                by that split name only.
            **kwargs: Additional keyword arguments for datasets.load_dataset
                (must be hashable — see load_dataset_cached).
        """
        # Set the task name
        self._current_task = task_name or self._default_task

        # Call the cached static method
        raw_data = self.load_dataset_cached(
            self._data_source,
            task_name=self._current_task,
            trust_remote_code=trust_remote_code,
            split=split,
            **kwargs,
        )

        # Handle split-specific loading: requesting a concrete split returns the
        # split data directly, otherwise a mapping keyed by split name.
        if split:
            self.raw_data = {split: raw_data}
            self.split_names = [split]
        else:
            self.raw_data = raw_data
            self.split_names = list(raw_data.keys())

        print(
            f"Loaded dataset with {len(self.split_names)} groups: {', '.join(self.split_names)}."
        )

    def __repr__(self) -> str:
        status = "loaded" if self.raw_data is not None else "not loaded"
        parsed_count = len(self._parsed_data) if self._parsed_data else 0
        return (
            f"{self.__class__.__name__}("
            f"data_source='{self._data_source}', "
            f"task='{self._current_task}', "
            f"status='{status}', "
            f"parsed_entries={parsed_count}"
            ")"
        )

    def __str__(self) -> str:
        return self.__repr__()