File size: 5,514 Bytes
424ff6a
 
 
e5427e0
 
 
 
 
 
424ff6a
 
 
 
 
 
 
 
 
 
 
 
 
0450c4e
424ff6a
 
 
 
 
 
 
 
0450c4e
424ff6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0450c4e
424ff6a
 
0450c4e
424ff6a
 
 
 
 
 
 
 
e5427e0
 
 
 
 
 
 
 
a06316f
e5427e0
 
 
 
 
 
27ff91e
 
 
 
 
e5427e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424ff6a
 
 
 
 
 
 
 
 
0450c4e
424ff6a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from dataclasses import dataclass
from typing import Any, ClassVar

from llmdataparser.base_parser import (
    DatasetDescription,
    EvaluationMetric,
    HuggingFaceDatasetParser,
    HuggingFaceParseEntry,
)


@dataclass(frozen=True, kw_only=True, slots=True)
class GSM8KParseEntry(HuggingFaceParseEntry):
    """Immutable, keyword-only record for one parsed GSM8K example.

    Adds the GSM8K-specific fields below on top of whatever
    ``HuggingFaceParseEntry`` declares (presumably ``question``, ``answer``,
    ``raw_question`` and ``raw_answer``, since ``create`` forwards them to the
    constructor -- confirm against the base class).
    """

    # Worked solution text for the problem.
    solution: str
    # Final answer of the problem as a number.
    numerical_answer: int | float
    # Name of the dataset task/config this entry belongs to.
    task_name: str

    @classmethod
    def create(
        cls,
        question: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        solution: str,
        numerical_answer: int | float,
        task_name: str,
    ) -> "GSM8KParseEntry":
        """Build an entry from explicit values.

        Every argument is passed to the dataclass constructor unchanged under
        the keyword of the same name; no validation or conversion happens
        here.

        Returns:
            A new ``GSM8KParseEntry`` (or subclass, when called on one).
        """
        field_values: dict[str, Any] = {
            "question": question,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "solution": solution,
            "numerical_answer": numerical_answer,
            "task_name": task_name,
        }
        return cls(**field_values)


class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
    """Parser for the GSM8K dataset.

    Converts raw GSM8K rows (``question`` / ``answer`` pairs) into
    ``GSM8KParseEntry`` objects, splitting each answer at the ``####`` marker
    into the worked solution and the final numerical answer.
    """

    _data_source: ClassVar[str] = "openai/gsm8k"
    _task_names: ClassVar[list[str]] = ["main", "socratic"]
    _default_task: ClassVar[str] = "main"

    @staticmethod
    def _parse_number(numerical_str: str) -> int | float:
        """Convert the text of a final answer into an ``int`` or ``float``.

        Args:
            numerical_str: Answer text with surrounding whitespace and
                thousands separators already removed.

        Returns:
            An ``int`` when the text denotes a whole number (including forms
            such as ``"18.0"`` or ``"1e3"``), otherwise a ``float``.

        Raises:
            ValueError: If the text cannot be interpreted as a number.
        """
        try:
            # Parse integers directly: going through float first would
            # silently round whole numbers larger than 2**53.
            return int(numerical_str)
        except ValueError:
            pass

        try:
            value = float(numerical_str)
        except ValueError as err:
            raise ValueError(f"Could not convert '{numerical_str}' to number") from err

        # Whole-valued floats are normalised to int; inf/nan are not integers
        # and are returned as floats unchanged.
        return int(value) if value.is_integer() else value

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> GSM8KParseEntry:
        """Process a single GSM8K entry.

        Args:
            row: Raw dataset row; must provide the ``"question"`` and
                ``"answer"`` keys.
            task_name: Task to record on the entry. When falsy, the task is
                obtained from ``self._get_current_task(row)``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The parsed entry. ``answer`` is the string form of the numerical
            answer, ``solution`` is the text before the first ``####`` marker.

        Raises:
            KeyError: If ``row`` lacks ``"question"`` or ``"answer"``.
            ValueError: If the text after the last ``####`` marker is not a
                number.
        """
        task = task_name or self._get_current_task(row)
        raw_question = row["question"]
        raw_answer = row["answer"]

        # The final answer follows the last '####'; commas are thousands
        # separators. Without a marker the whole answer text is used.
        numerical_str = raw_answer.rpartition("####")[2].strip().replace(",", "")
        numerical_answer = self._parse_number(numerical_str)

        # The worked solution is everything before the first '####'.
        solution = raw_answer.partition("####")[0].strip()

        question = str(raw_question)

        return GSM8KParseEntry.create(
            question=question,
            answer=str(numerical_answer),
            raw_question=raw_question,
            raw_answer=raw_answer,
            solution=solution,
            numerical_answer=numerical_answer,
            task_name=task,
        )

    def get_dataset_description(self) -> DatasetDescription:
        """Returns description of the GSM8K dataset.

        The description (name, purpose, source, language, format, category,
        characteristics and BibTeX citation) is static metadata built with
        ``DatasetDescription.create``.
        """
        return DatasetDescription.create(
            name="Grade School Math 8K (GSM8K)",
            purpose="Evaluate mathematical reasoning capabilities through word problems",
            source="OpenAI",
            language="English",
            format="Word problems with step-by-step solutions and numerical answers",
            category=["Math"],
            characteristics=(
                "Collection of 8.5K grade school math word problems that require "
                "multi-step reasoning. Problems gradually increase in difficulty "
                "and cover basic arithmetic, word problems, and elementary algebra"
            ),
            citation="""@article{cobbe2021gsm8k,
    title={Training Verifiers to Solve Math Word Problems},
    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
    journal={arXiv preprint arXiv:2110.14168},
    year={2021}
    }""",
        )

    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
        """Returns recommended evaluation metrics for GSM8K.

        Four metrics are listed: ``exact_match``, ``solution_validity`` and
        ``step_accuracy`` are flagged as primary, ``step_count`` is not.
        """
        return [
            EvaluationMetric.create(
                name="exact_match",
                type="string",
                description="Exact match comparison between predicted and correct numerical answers",
                implementation="custom_exact_match",
                primary=True,
            ),
            EvaluationMetric.create(
                name="solution_validity",
                type="text",
                description="Assessment of whether the solution steps are mathematically valid and complete",
                implementation="custom_solution_validator",
                primary=True,
            ),
            EvaluationMetric.create(
                name="step_accuracy",
                type="numerical",
                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
                implementation="custom_step_accuracy",
                primary=True,
            ),
            EvaluationMetric.create(
                name="step_count",
                type="numerical",
                description="Analysis of the number of reasoning steps in solutions",
                implementation="custom_step_counter",
                primary=False,
            ),
        ]


if __name__ == "__main__":
    from pprint import pprint

    # Manual smoke test: load the dataset, parse it, then dump every field of
    # the first parsed entry.
    parser = GSM8KDatasetParser()
    parser.load()
    parser.parse()

    # `get_parsed_data` is read as an attribute here, not called.
    parsed_data = parser.get_parsed_data
    first_entry = parsed_data[0]
    for field_name in (
        "question",
        "answer",
        "raw_question",
        "raw_answer",
        "solution",
        "numerical_answer",
    ):
        pprint(getattr(first_entry, field_name))