refactor: gsm8k parser
Files changed:
- llmdataparser/base_parser.py   +2 -2
- llmdataparser/gsm8k_parser.py  +60 -1
- tests/test_gsm8k_parser.py     +44 -0
llmdataparser/base_parser.py
CHANGED
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Any, ClassVar, Generic,
+from typing import Any, ClassVar, Generic, TypeVar

 import datasets

@@ -130,7 +130,7 @@ class DatasetParser(Generic[T], ABC):
             characteristics="Not specified",
         )

-    def get_evaluation_metrics(self) ->
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
         """Returns the recommended evaluation metrics for the dataset."""
         return []

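The typed default also gives downstream code something to rely on. A small hypothetical helper (not part of this commit), assuming only what the diff shows, namely EvaluationMetric values carrying a primary flag:

from llmdataparser.base_parser import DatasetParser, EvaluationMetric


def primary_metrics(parser: DatasetParser) -> list[EvaluationMetric]:
    """Return only the metrics a parser marks as primary."""
    # get_evaluation_metrics() now has a typed default of [], so this
    # works on any parser, whether or not it overrides the hook.
    return [m for m in parser.get_evaluation_metrics() if m.primary]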
llmdataparser/gsm8k_parser.py
CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, ClassVar

-from llmdataparser.base_parser import
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import GSM8K_SYSTEM_PROMPT


@@ -76,6 +81,60 @@ class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
             task_name=task,  # Guarantee non-None
         )

+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns description of the GSM8K dataset."""
+        return DatasetDescription.create(
+            name="Grade School Math 8K (GSM8K)",
+            purpose="Evaluate mathematical reasoning capabilities through word problems",
+            source="OpenAI",
+            language="English",
+            format="Word problems with step-by-step solutions and numerical answers",
+            characteristics=(
+                "Collection of 8.5K grade school math word problems that require "
+                "multi-step reasoning. Problems gradually increase in difficulty "
+                "and cover basic arithmetic, word problems, and elementary algebra"
+            ),
+            citation="""@article{cobbe2021gsm8k,
+    title={Training Verifiers to Solve Math Word Problems},
+    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
+    journal={arXiv preprint arXiv:2110.14168},
+    year={2021}
+}""",
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns recommended evaluation metrics for GSM8K."""
+        return [
+            EvaluationMetric.create(
+                name="exact_match",
+                type="string",
+                description="Exact match comparison between predicted and correct numerical answers",
+                implementation="custom_exact_match",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="solution_validity",
+                type="text",
+                description="Assessment of whether the solution steps are mathematically valid and complete",
+                implementation="custom_solution_validator",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="step_accuracy",
+                type="numerical",
+                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
+                implementation="custom_step_accuracy",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="step_count",
+                type="numerical",
+                description="Analysis of the number of reasoning steps in solutions",
+                implementation="custom_step_counter",
+                primary=False,
+            ),
+        ]
+

 if __name__ == "__main__":
     from pprint import pprint

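A quick smoke test for the two new methods; this is a sketch that assumes GSM8KDatasetParser takes no constructor arguments, which this diff does not show:

from llmdataparser.gsm8k_parser import GSM8KDatasetParser

parser = GSM8KDatasetParser()  # assumed no-arg constructor

description = parser.get_dataset_description()
print(description.name)    # Grade School Math 8K (GSM8K)
print(description.source)  # OpenAI

for metric in parser.get_evaluation_metrics():
    kind = "primary" if metric.primary else "secondary"
    print(f"{metric.name} [{metric.type}] ({kind}) -> {metric.implementation}")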
tests/test_gsm8k_parser.py
CHANGED
@@ -181,3 +181,47 @@ def test_different_splits_parsing(gsm8k_parser):
     assert test_count > 0
     assert train_count > 0
     assert train_count != test_count
+
+
+def test_get_dataset_description(gsm8k_parser):
+    """Test dataset description generation."""
+    description = gsm8k_parser.get_dataset_description()
+
+    assert description.name == "Grade School Math 8K (GSM8K)"
+    assert description.source == "OpenAI"
+    assert description.language == "English"
+    assert "8.5K grade school math word problems" in description.characteristics
+    assert "Training Verifiers to Solve Math Word Problems" in description.citation
+    assert "Cobbe" in description.citation
+    assert "arXiv" in description.citation
+
+
+def test_get_evaluation_metrics(gsm8k_parser):
+    """Test evaluation metrics specification."""
+    metrics = gsm8k_parser.get_evaluation_metrics()
+
+    # Check we have all expected metrics
+    metric_names = {metric.name for metric in metrics}
+    expected_names = {"exact_match", "solution_validity", "step_accuracy", "step_count"}
+    assert metric_names == expected_names
+
+    # Check exact_match metric details
+    exact_match = next(m for m in metrics if m.name == "exact_match")
+    assert exact_match.type == "string"
+    assert exact_match.primary is True
+    assert "exact match" in exact_match.description.lower()
+
+    # Check solution_validity metric details
+    solution_validity = next(m for m in metrics if m.name == "solution_validity")
+    assert solution_validity.type == "text"
+    assert solution_validity.primary is True
+    assert "valid" in solution_validity.description.lower()
+
+    # Check step metrics
+    step_accuracy = next(m for m in metrics if m.name == "step_accuracy")
+    assert step_accuracy.type == "numerical"
+    assert step_accuracy.primary is True
+
+    step_count = next(m for m in metrics if m.name == "step_count")
+    assert step_count.type == "numerical"
+    assert step_count.primary is False
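The metric entries name custom implementations (custom_exact_match, custom_step_accuracy, and so on) that are not part of this diff. As a rough illustration only: GSM8K reference solutions end in "#### <answer>" and annotate intermediate arithmetic as <<48/2=24>>, so the two numeric checks could be sketched like this (hypothetical code, not the repo's implementations):

import re

ANSWER_RE = re.compile(r"####\s*([-+]?[\d,.]+)")
STEP_RE = re.compile(r"<<([^<>=]+)=([^<>]+)>>")


def extract_final_answer(solution: str) -> str | None:
    """Pull the final numeric answer out of a GSM8K-style solution."""
    match = ANSWER_RE.search(solution)
    return match.group(1).replace(",", "") if match else None


def exact_match(predicted: str, reference: str) -> bool:
    """Compare final answers only, ignoring the reasoning text."""
    pred = extract_final_answer(predicted)
    ref = extract_final_answer(reference)
    return pred is not None and pred == ref


def step_accuracy(solution: str) -> float:
    """Fraction of <<expr=result>> annotations whose arithmetic checks out."""
    steps = STEP_RE.findall(solution)
    if not steps:
        return 0.0
    correct = 0
    for expr, result in steps:
        try:
            # eval is used here only as a sketch on trusted dataset text;
            # a real implementation should parse the arithmetic properly.
            if abs(eval(expr) - float(result)) < 1e-6:
                correct += 1
        except Exception:
            continue
    return correct / len(steps)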