JeffYang52415 committed
Commit e5427e0 · unverified · 1 Parent(s): 58d5612

refactor: gsm8k parser

llmdataparser/base_parser.py CHANGED
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Any, ClassVar, Generic, List, TypeVar
+from typing import Any, ClassVar, Generic, TypeVar
 
 import datasets
 
@@ -130,7 +130,7 @@ class DatasetParser(Generic[T], ABC):
             characteristics="Not specified",
         )
 
-    def get_evaluation_metrics(self) -> List[EvaluationMetric]:
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
         """Returns the recommended evaluation metrics for the dataset."""
         return []
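Note on the typing change above: since Python 3.9 (PEP 585), builtin container types such as list can be used directly in annotations, so the typing.List import becomes unnecessary. A minimal before/after sketch, using str as a stand-in element type (illustrative only, not code from this repo):

    from typing import List

    def old_style() -> List[str]:  # pre-3.9 spelling, requires the typing import
        return []

    def new_style() -> list[str]:  # PEP 585 builtin generic, Python 3.9+
        return []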
llmdataparser/gsm8k_parser.py CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, ClassVar
 
-from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import GSM8K_SYSTEM_PROMPT
 
 
@@ -76,6 +81,60 @@ class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
             task_name=task,  # Guarantee non-None
         )
 
+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns description of the GSM8K dataset."""
+        return DatasetDescription.create(
+            name="Grade School Math 8K (GSM8K)",
+            purpose="Evaluate mathematical reasoning capabilities through word problems",
+            source="OpenAI",
+            language="English",
+            format="Word problems with step-by-step solutions and numerical answers",
+            characteristics=(
+                "Collection of 8.5K grade school math word problems that require "
+                "multi-step reasoning. Problems gradually increase in difficulty "
+                "and cover basic arithmetic, word problems, and elementary algebra"
+            ),
+            citation="""@article{cobbe2021gsm8k,
+    title={Training Verifiers to Solve Math Word Problems},
+    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
+    journal={arXiv preprint arXiv:2110.14168},
+    year={2021}
+}""",
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns recommended evaluation metrics for GSM8K."""
+        return [
+            EvaluationMetric.create(
+                name="exact_match",
+                type="string",
+                description="Exact match comparison between predicted and correct numerical answers",
+                implementation="custom_exact_match",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="solution_validity",
+                type="text",
+                description="Assessment of whether the solution steps are mathematically valid and complete",
+                implementation="custom_solution_validator",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="step_accuracy",
+                type="numerical",
+                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
+                implementation="custom_step_accuracy",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="step_count",
+                type="numerical",
+                description="Analysis of the number of reasoning steps in solutions",
+                implementation="custom_step_counter",
+                primary=False,
+            ),
+        ]
+
 
 if __name__ == "__main__":
     from pprint import pprint
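For reviewers, a minimal usage sketch of the two methods this commit adds. The zero-argument instantiation is an assumption; the diff does not show GSM8KDatasetParser's constructor:

    # Hypothetical sketch; assumes GSM8KDatasetParser() needs no constructor args.
    from pprint import pprint

    from llmdataparser.gsm8k_parser import GSM8KDatasetParser

    parser = GSM8KDatasetParser()
    pprint(parser.get_dataset_description())  # DatasetDescription metadata for GSM8K
    pprint(parser.get_evaluation_metrics())   # four EvaluationMetric entries, three primary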
tests/test_gsm8k_parser.py CHANGED
@@ -181,3 +181,47 @@ def test_different_splits_parsing(gsm8k_parser):
181
  assert test_count > 0
182
  assert train_count > 0
183
  assert train_count != test_count
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  assert test_count > 0
182
  assert train_count > 0
183
  assert train_count != test_count
184
+
185
+
186
+ def test_get_dataset_description(gsm8k_parser):
187
+ """Test dataset description generation."""
188
+ description = gsm8k_parser.get_dataset_description()
189
+
190
+ assert description.name == "Grade School Math 8K (GSM8K)"
191
+ assert description.source == "OpenAI"
192
+ assert description.language == "English"
193
+ assert "8.5K grade school math word problems" in description.characteristics
194
+ assert "Training Verifiers to Solve Math Word Problems" in description.citation
195
+ assert "Cobbe" in description.citation
196
+ assert "arXiv" in description.citation
197
+
198
+
199
+ def test_get_evaluation_metrics(gsm8k_parser):
200
+ """Test evaluation metrics specification."""
201
+ metrics = gsm8k_parser.get_evaluation_metrics()
202
+
203
+ # Check we have all expected metrics
204
+ metric_names = {metric.name for metric in metrics}
205
+ expected_names = {"exact_match", "solution_validity", "step_accuracy", "step_count"}
206
+ assert metric_names == expected_names
207
+
208
+ # Check exact_match metric details
209
+ exact_match = next(m for m in metrics if m.name == "exact_match")
210
+ assert exact_match.type == "string"
211
+ assert exact_match.primary is True
212
+ assert "exact match" in exact_match.description.lower()
213
+
214
+ # Check solution_validity metric details
215
+ solution_validity = next(m for m in metrics if m.name == "solution_validity")
216
+ assert solution_validity.type == "text"
217
+ assert solution_validity.primary is True
218
+ assert "valid" in solution_validity.description.lower()
219
+
220
+ # Check step metrics
221
+ step_accuracy = next(m for m in metrics if m.name == "step_accuracy")
222
+ assert step_accuracy.type == "numerical"
223
+ assert step_accuracy.primary is True
224
+
225
+ step_count = next(m for m in metrics if m.name == "step_count")
226
+ assert step_count.type == "numerical"
227
+ assert step_count.primary is False
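To run only the two tests added here, pytest's -k name filter is one option; a sketch of the invocation, wrapped in Python via pytest.main:

    # Equivalent to the command line:
    #   pytest tests/test_gsm8k_parser.py -k "get_dataset_description or get_evaluation_metrics"
    import pytest

    raise SystemExit(
        pytest.main(
            ["tests/test_gsm8k_parser.py", "-k", "get_dataset_description or get_evaluation_metrics"]
        )
    )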