from dataclasses import dataclass
from typing import Any, ClassVar
from llmdataparser.base_parser import (
DatasetDescription,
EvaluationMetric,
HuggingFaceDatasetParser,
HuggingFaceParseEntry,
)
from llmdataparser.prompts import MGSM_SYSTEM_PROMPT


@dataclass(frozen=True, kw_only=True, slots=True)
class MGSMParseEntry(HuggingFaceParseEntry):
    """Custom entry class for MGSM, with fields specific to this dataset parser."""

    numerical_answer: int | float
    equation_solution: str | None
    language: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        numerical_answer: int | float,
        equation_solution: str | None,
        task_name: str,
        language: str,
    ) -> "MGSMParseEntry":
        return cls(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            numerical_answer=numerical_answer,
            equation_solution=equation_solution,
            task_name=task_name,
            language=language,
        )


class MGSMDatasetParser(HuggingFaceDatasetParser[MGSMParseEntry]):
    """Parser for the MGSM (Multilingual Grade School Math) dataset."""

    _data_source: ClassVar[str] = "juletxara/mgsm"
    _default_task: ClassVar[str] = "en"
    _task_names: ClassVar[list[str]] = [
        "bn",
        "de",
        "en",
        "es",
        "fr",
        "ja",
        "ru",
        "sw",
        "te",
        "th",
        "zh",
    ]
    _default_system_prompt: ClassVar[str] = MGSM_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MGSMParseEntry:
        """
        Process a single MGSM entry.

        Args:
            row: Dictionary containing the MGSM entry fields
            task_name: Language code for the current task

        Returns:
            MGSMParseEntry: Processed entry with prompt, answer, and metadata
        """
        task = task_name or self._get_current_task(row)
        raw_question = row["question"]
        raw_answer = row["answer"] if row["answer"] else ""
        numerical_answer = row["answer_number"]
        equation_solution = row["equation_solution"]

        # Construct the prompt with the system prompt and question
        prompt = f"{self._system_prompt}\n{raw_question}"

        # Use numerical answer as string for the answer field if no detailed answer is provided
        answer = raw_answer if raw_answer else str(numerical_answer)

        return MGSMParseEntry.create(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            numerical_answer=numerical_answer,
            equation_solution=equation_solution,
            task_name=task,
            language=task,
        )
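
    # For illustration only: a raw MGSM row has roughly this shape (the
    # values below are invented for the example, not taken from the dataset):
    #
    #     row = {
    #         "question": "A farmer has 48 apples and sells half...",
    #         "answer": "Half of 48 is <<48/2=24>>24. #### 24",
    #         "answer_number": 24,
    #         "equation_solution": "48 / 2 = 24",
    #     }
    #
    # process_entry(row, task_name="en") then yields an MGSMParseEntry whose
    # prompt is the system prompt followed by the question, and whose answer
    # falls back to str(24) for splits where the solution text is empty.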

    def get_dataset_description(self) -> DatasetDescription:
        """Returns a description of the Multilingual Grade School Math dataset."""
        return DatasetDescription.create(
            name="Multilingual Grade School Math (MGSM)",
            purpose="Evaluate multilingual chain-of-thought reasoning capabilities in mathematical problem solving",
            source="https://huggingface.co/datasets/juletxara/mgsm",
            language="Multilingual (11 languages)",
            format="Word problems with numerical answers and solution steps",
            characteristics=(
                "Human-translated version of 250 GSM8K problems into 10 additional languages. "
                "Each problem includes the original question from GSM8K, its translations, "
                "numerical answer, and solution steps. The benchmark is designed to evaluate "
                "language models' ability to perform mathematical reasoning across different languages."
            ),
            citation="""@misc{shi2022language,
    title={Language Models are Multilingual Chain-of-Thought Reasoners},
    author={Freda Shi and Mirac Suzgun and Markus Freitag and Xuezhi Wang and Suraj Srivats and Soroush Vosoughi and Hyung Won Chung and Yi Tay and Sebastian Ruder and Denny Zhou and Dipanjan Das and Jason Wei},
    year={2022},
    eprint={2210.03057},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
@article{cobbe2021gsm8k,
    title={Training Verifiers to Solve Math Word Problems},
    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
    journal={arXiv preprint arXiv:2110.14168},
    year={2021}
}""",
            additional_info={
                "languages": [
                    "Bengali",
                    "German",
                    "English",
                    "Spanish",
                    "French",
                    "Japanese",
                    "Russian",
                    "Swahili",
                    "Telugu",
                    "Thai",
                    "Chinese",
                ],
                "size": "250 problems translated into each language",
                "base_dataset": "GSM8K (Grade School Math 8K)",
            },
        )

    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
        """Returns the recommended evaluation metrics for the MGSM dataset."""
        return [
            EvaluationMetric.create(
                name="exact_match",
                type="string",
                description="Exact match comparison between predicted and correct numerical answers",
                implementation="custom_exact_match",
                primary=True,
            ),
            EvaluationMetric.create(
                name="solution_validity",
                type="text",
                description="Assessment of whether the solution steps are mathematically valid and complete",
                implementation="custom_solution_validator",
                primary=True,
            ),
            EvaluationMetric.create(
                name="step_accuracy",
                type="numerical",
                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
                implementation="custom_step_accuracy",
                primary=True,
            ),
            EvaluationMetric.create(
                name="cross_lingual_consistency",
                type="comparison",
                description="Consistency of model performance across different language versions of the same problem",
                implementation="custom_language_comparator",
                primary=False,
            ),
        ]
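

# The `implementation` names above refer to evaluator callables defined
# elsewhere in the project and not shown in this file. As a minimal,
# hypothetical sketch (the names `sketch_exact_match` / `sketch_step_accuracy`
# and their signatures are assumptions for illustration, not the project's
# actual API), the first and third checks could look like this:
def sketch_exact_match(predicted: str, numerical_answer: int | float) -> bool:
    """Compare a predicted final answer against the reference number."""
    try:
        return float(predicted.strip().replace(",", "")) == float(numerical_answer)
    except ValueError:
        return False


def sketch_step_accuracy(solution: str) -> float:
    """Fraction of <<expr=result>> annotations whose arithmetic checks out."""
    import re

    steps = re.findall(r"<<([^=<>]+)=([^<>]+)>>", solution)
    if not steps:
        return 0.0
    correct = 0
    for expr, result in steps:
        try:
            # eval() is tolerable here only because this is an illustration;
            # a real implementation should use a safe arithmetic parser.
            if abs(eval(expr) - float(result)) < 1e-6:
                correct += 1
        except (SyntaxError, ValueError, ZeroDivisionError):
            pass
    return correct / len(steps)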


if __name__ == "__main__":
    from pprint import pprint

    parser = MGSMDatasetParser()
    parser.load(task_name="en")  # Load the English split
    parser.parse()

    parsed_data = parser.get_parsed_data
    pprint(parsed_data[0].prompt)
    pprint(parsed_data[0].answer)
    pprint(parsed_data[0].raw_question)
    pprint(parsed_data[0].numerical_answer)
    pprint(parsed_data[0].language)