File size: 5,610 Bytes
424ff6a e5427e0 424ff6a e5427e0 27ff91e e5427e0 424ff6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from dataclasses import dataclass
from typing import Any, ClassVar
from llmdataparser.base_parser import (
DatasetDescription,
EvaluationMetric,
HuggingFaceDatasetParser,
HuggingFaceParseEntry,
)
from llmdataparser.prompts import GSM8K_SYSTEM_PROMPT
@dataclass(frozen=True, kw_only=True, slots=True)
class GSM8KParseEntry(HuggingFaceParseEntry):
    """Parse entry for GSM8K with dataset-specific fields.

    Extends the base entry with the worked solution text, the extracted
    numerical answer, and the task (config) name the row came from.
    """

    # Step-by-step worked solution (the text before '####' in the raw answer).
    solution: str
    # Final answer: int when the value is whole, float otherwise.
    numerical_answer: int | float
    # Dataset configuration name (e.g. "main" or "socratic").
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        solution: str,
        numerical_answer: int | float,
        task_name: str,
    ) -> "GSM8KParseEntry":
        """Alternate constructor; forwards every field as a keyword argument."""
        field_values = dict(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            solution=solution,
            numerical_answer=numerical_answer,
            task_name=task_name,
        )
        return cls(**field_values)
class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
    """Parser for the GSM8K grade-school math dataset.

    Each raw answer has the form "<solution steps> #### <final number>";
    this parser splits it into the worked solution and the numerical answer.
    """

    # Hugging Face dataset coordinates and parser defaults.
    _data_source: ClassVar[str] = "openai/gsm8k"
    _task_names: ClassVar[list[str]] = ["main", "socratic"]
    _default_task: ClassVar[str] = "main"
    _default_system_prompt: ClassVar[str] = GSM8K_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> GSM8KParseEntry:
        """Process a single GSM8K entry.

        Args:
            row: Raw dataset row; must contain "question" and "answer" keys.
            task_name: Optional task override; falls back to the current task.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            A GSM8KParseEntry with the prompt, the solution steps, and the
            numerical answer coerced to ``int`` when it is a whole number.

        Raises:
            ValueError: If the text after '####' cannot be parsed as a number.
        """
        task = task_name or self._get_current_task(row)
        raw_question = row["question"]
        raw_answer = row["answer"]

        # The numerical answer always follows '####'; the solution precedes it.
        # Split once and reuse both halves (the original split twice).
        parts = raw_answer.split("####")
        numerical_str = parts[-1].strip().replace(",", "")
        solution = parts[0].strip()

        # Normalize whole-valued floats to int (e.g. "42.0" -> 42).
        try:
            numerical_answer = float(numerical_str)
            if numerical_answer.is_integer():
                numerical_answer = int(numerical_answer)
        except ValueError as exc:
            # Chain the original error so the underlying cause is preserved.
            raise ValueError(
                f"Could not convert '{numerical_str}' to number"
            ) from exc

        prompt = f"{self._system_prompt}\n{raw_question}"

        return GSM8KParseEntry.create(
            prompt=prompt,
            answer=str(numerical_answer),
            raw_question=raw_question,
            raw_answer=raw_answer,
            solution=solution,
            numerical_answer=numerical_answer,  # guaranteed int or float
            task_name=task,  # guaranteed non-None by the fallback above
        )

    def get_dataset_description(self) -> DatasetDescription:
        """Returns description of the GSM8K dataset."""
        return DatasetDescription.create(
            name="Grade School Math 8K (GSM8K)",
            purpose="Evaluate mathematical reasoning capabilities through word problems",
            source="OpenAI",
            language="English",
            format="Word problems with step-by-step solutions and numerical answers",
            characteristics=(
                "Collection of 8.5K grade school math word problems that require "
                "multi-step reasoning. Problems gradually increase in difficulty "
                "and cover basic arithmetic, word problems, and elementary algebra"
            ),
            citation="""@article{cobbe2021gsm8k,
    title={Training Verifiers to Solve Math Word Problems},
    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
    journal={arXiv preprint arXiv:2110.14168},
    year={2021}
}""",
        )

    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
        """Returns recommended evaluation metrics for GSM8K."""
        return [
            EvaluationMetric.create(
                name="exact_match",
                type="string",
                description="Exact match comparison between predicted and correct numerical answers",
                implementation="custom_exact_match",
                primary=True,
            ),
            EvaluationMetric.create(
                name="solution_validity",
                type="text",
                description="Assessment of whether the solution steps are mathematically valid and complete",
                implementation="custom_solution_validator",
                primary=True,
            ),
            EvaluationMetric.create(
                name="step_accuracy",
                type="numerical",
                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
                implementation="custom_step_accuracy",
                primary=True,
            ),
            EvaluationMetric.create(
                name="step_count",
                type="numerical",
                description="Analysis of the number of reasoning steps in solutions",
                implementation="custom_step_counter",
                primary=False,
            ),
        ]
if __name__ == "__main__":
    from pprint import pprint

    # Quick smoke run: load the dataset, parse it, and dump the first entry.
    parser = GSM8KDatasetParser()
    parser.load()
    parser.parse()

    # NOTE(review): get_parsed_data is accessed without parentheses —
    # presumably a property on the base parser; verify against base_parser.
    parsed_data = parser.get_parsed_data
    first_entry = parsed_data[0]
    for attr in (
        "prompt",
        "answer",
        "raw_question",
        "raw_answer",
        "solution",
        "numerical_answer",
    ):
        pprint(getattr(first_entry, attr))
|