JeffYang52415 committed on
Commit
424ff6a
·
unverified ·
1 Parent(s): 44529bb

feat: add gsm8k parser

Browse files
llmdataparser/gsm8k_parser.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, ClassVar
3
+
4
+ from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
5
+ from llmdataparser.prompts import GSM8K_SYSTEM_PROMPT
6
+
7
+
8
@dataclass(frozen=True, kw_only=True, slots=True)
class GSM8KParseEntry(HuggingFaceParseEntry):
    """Parsed GSM8K record carrying the extra fields this dataset needs.

    The remaining fields accepted by ``create`` (prompt, answer,
    raw_question, raw_answer) are presumably declared on
    ``HuggingFaceParseEntry`` -- confirm against the base class.
    """

    # Solution text of the problem.
    solution: str
    # Final answer as a number.
    numerical_answer: int | float
    # Name of the dataset task the entry belongs to.
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        solution: str,
        numerical_answer: int | float,
        task_name: str,
    ) -> "GSM8KParseEntry":
        """Build an entry, forwarding every argument as a keyword field."""
        field_values = {
            "prompt": prompt,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "solution": solution,
            "numerical_answer": numerical_answer,
            "task_name": task_name,
        }
        return cls(**field_values)
36
+
37
+
38
class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
    """Parser for the GSM8K dataset.

    Each row's ``answer`` text is split on the ``####`` marker: the text
    before the first marker is the solution, and the text after the last
    marker is the numerical answer.
    """

    _data_source: ClassVar[str] = "openai/gsm8k"
    _task_names: ClassVar[list[str]] = ["main", "socratic"]
    _default_task: ClassVar[str] = "main"
    _default_system_prompt: ClassVar[str] = GSM8K_SYSTEM_PROMPT

    @staticmethod
    def _parse_numerical_answer(numerical_str: str) -> int | float:
        """Convert a cleaned answer string to ``int`` when possible, else ``float``.

        Raises:
            ValueError: If the string is not a valid number. The original
                conversion error is chained as the cause.
        """
        # Try int first so large integers are not rounded through a float.
        try:
            return int(numerical_str)
        except ValueError:
            pass

        try:
            value = float(numerical_str)
        except ValueError as err:
            raise ValueError(
                f"Could not convert '{numerical_str}' to number"
            ) from err

        # Whole-valued floats such as "5.0" or "1e3" are reported as ints.
        return int(value) if value.is_integer() else value

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> GSM8KParseEntry:
        """Process a single GSM8K entry.

        Args:
            row: Dataset row; its "question" and "answer" keys are read.
            task_name: Task to record on the entry. When falsy, the task is
                taken from ``self._get_current_task(row)``.
            **kwargs: Accepted but not used here.

        Returns:
            The parsed entry, with ``answer`` set to the string form of the
            numerical answer.

        Raises:
            ValueError: If the text after the last ``####`` is not a number.
        """
        task = task_name or self._get_current_task(row)
        raw_question = row["question"]
        raw_answer = row["answer"]

        # Split once; first segment is the solution, last is the answer.
        segments = raw_answer.split("####")
        solution = segments[0].strip()
        # Thousands separators ("1,234") are dropped before conversion.
        numerical_str = segments[-1].strip().replace(",", "")
        numerical_answer = self._parse_numerical_answer(numerical_str)

        prompt = f"{self._system_prompt}\n{raw_question}"

        return GSM8KParseEntry.create(
            prompt=prompt,
            answer=str(numerical_answer),
            raw_question=raw_question,
            raw_answer=raw_answer,
            solution=solution,
            numerical_answer=numerical_answer,
            task_name=task,
        )
78
+
79
+
80
if __name__ == "__main__":
    from pprint import pprint

    # Demo run: load the dataset, parse it, and print the first entry.
    parser = GSM8KDatasetParser()
    parser.load()
    parser.parse()

    parsed_data = parser.get_parsed_data
    for field_name in (
        "prompt",
        "answer",
        "raw_question",
        "raw_answer",
        "solution",
        "numerical_answer",
    ):
        pprint(getattr(parsed_data[0], field_name))
tests/test_gsm8k_parser.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from llmdataparser.gsm8k_parser import GSM8KDatasetParser, GSM8KParseEntry
4
+
5
+
6
@pytest.fixture
def gsm8k_parser():
    """Provide a new GSM8K parser instance with nothing loaded."""
    parser = GSM8KDatasetParser()
    return parser
10
+
11
+
12
@pytest.fixture
def loaded_gsm8k_parser(gsm8k_parser):
    """Provide a GSM8K parser with the "main" task's test split loaded."""
    # The test split is chosen because it is the smaller one.
    load_options = {"task_name": "main", "split": "test"}
    gsm8k_parser.load(**load_options)
    return gsm8k_parser
19
+
20
+
21
@pytest.fixture
def sample_row():
    """Provide one hand-written row shaped like a GSM8K record."""
    question = (
        "Janet has 3 apples. She buys 2 more. How many apples does she have now?"
    )
    answer = (
        "Let's solve this step by step:\n"
        "1) Initially, Janet has 3 apples\n"
        "2) She buys 2 more apples\n"
        "3) Total apples = 3 + 2\n"
        "#### 5"
    )
    return {"question": question, "answer": answer}
28
+
29
+
30
def test_gsm8k_parse_entry_creation_valid():
    """GSM8KParseEntry.create returns an entry holding the given values."""
    creation_args = {
        "prompt": "Test prompt",
        "answer": "5",
        "raw_question": "Test question",
        "raw_answer": "Solution steps #### 5",
        "solution": "Solution steps",
        "task_name": "main",
        "numerical_answer": 5,
    }

    entry = GSM8KParseEntry.create(**creation_args)

    assert isinstance(entry, GSM8KParseEntry)
    assert entry.prompt == "Test prompt"
    assert entry.answer == "5"
    assert entry.solution == "Solution steps"
    assert entry.numerical_answer == 5
    assert entry.task_name == "main"
47
+
48
+
49
def test_gsm8k_parser_initialization(gsm8k_parser):
    """A new parser exposes the expected class-level configuration."""
    expected_link = "https://huggingface.co/datasets/openai/gsm8k"

    assert gsm8k_parser._data_source == "openai/gsm8k"
    assert gsm8k_parser._default_task == "main"
    assert gsm8k_parser._task_names == ["main", "socratic"]
    assert gsm8k_parser.get_huggingface_link == expected_link
58
+
59
+
60
def test_load_dataset(loaded_gsm8k_parser):
    """Loading records the raw data, the loaded split and the current task."""
    parser = loaded_gsm8k_parser

    assert parser.raw_data is not None
    # Only the test split was requested by the fixture.
    assert parser.split_names == ["test"]
    assert parser._current_task == "main"
67
+
68
+
69
@pytest.mark.integration
def test_full_parse_workflow(loaded_gsm8k_parser):
    """Test the complete workflow of loading and parsing data."""
    # Parse the test split
    loaded_gsm8k_parser.parse(split_names="test", force=True)
    parsed_data = loaded_gsm8k_parser.get_parsed_data

    # Basic checks
    assert len(parsed_data) > 0

    # Check first entry structure
    first_entry = parsed_data[0]
    assert isinstance(first_entry, GSM8KParseEntry)
    assert first_entry.task_name == "main"
    # numerical_answer is declared as int | float, so a str here would
    # mean the conversion did not happen and must fail the test.
    assert isinstance(first_entry.numerical_answer, (int, float))
    assert "####" in first_entry.raw_answer
    assert first_entry.solution
    assert first_entry.prompt.startswith(loaded_gsm8k_parser._system_prompt)
87
+
88
+
89
def test_process_entry(gsm8k_parser, sample_row):
    """process_entry turns the sample row into a fully populated entry."""
    result = gsm8k_parser.process_entry(sample_row, task_name="main")

    assert isinstance(result, GSM8KParseEntry)
    assert result.task_name == "main"
    assert result.numerical_answer == 5

    # Raw fields are carried over and the solution text is extracted.
    assert "Janet has 3 apples" in result.raw_question
    assert "#### 5" in result.raw_answer
    assert "Let's solve this step by step:" in result.solution

    assert gsm8k_parser._system_prompt in result.prompt
100
+
101
+
102
@pytest.mark.parametrize("split_name", ["invalid_split", "wrong_split"])
def test_parse_with_invalid_split(gsm8k_parser, split_name):
    """Parsing a split that is absent from the raw data raises ValueError."""
    # Stand-in raw data exposing only the "train" and "test" splits.
    gsm8k_parser.raw_data = {"train": [], "test": []}
    expected_message = f"Split '{split_name}' not found in the dataset"

    with pytest.raises(ValueError, match=expected_message):
        gsm8k_parser.parse(split_name)
111
+
112
+
113
def test_parse_without_loaded_data(gsm8k_parser):
    """Calling parse before load raises ValueError with a helpful message."""
    expected_message = "No data loaded. Please load the dataset first"

    with pytest.raises(ValueError, match=expected_message):
        gsm8k_parser.parse()
119
+
120
+
121
@pytest.mark.parametrize(
    "test_case",
    [
        {"question": "Test question", "answer": "Some solution steps #### 42"},
        {
            "question": "Test question",
            "answer": "Complex solution\nWith multiple lines\n#### 123.45",
        },
        {"question": "Test question", "answer": "No steps #### 0"},
        # Thousands separator: the parser strips commas before converting.
        {"question": "Test question", "answer": "Large total #### 1,234"},
    ],
)
def test_numerical_answer_extraction(gsm8k_parser, test_case):
    """Test extraction of numerical answers from different formats."""
    entry = gsm8k_parser.process_entry(test_case, task_name="main")

    # Expected value: the text after the last '####', trimmed, without commas.
    expected = test_case["answer"].split("####")[-1].strip().replace(",", "")
    assert str(entry.numerical_answer) == expected
138
+
139
+
140
def test_solution_extraction(gsm8k_parser):
    """The solution is the text before '####', without the marker itself."""
    answer_text = "Step 1: Do this\nStep 2: Do that\n#### 42"
    row = {"question": "Test question", "answer": answer_text}

    result = gsm8k_parser.process_entry(row, task_name="main")

    assert result.solution == "Step 1: Do this\nStep 2: Do that"
    assert result.task_name == "main"
    assert "####" not in result.solution
151
+
152
+
153
def test_parser_properties(gsm8k_parser):
    """The task_names and total_tasks properties report both GSM8K tasks."""
    expected_tasks = ["main", "socratic"]

    assert gsm8k_parser.task_names == expected_tasks
    assert gsm8k_parser.total_tasks == 2
157
+
158
+
159
def test_parser_string_representation(loaded_gsm8k_parser):
    """str() of a loaded parser mentions its class, source, task and state."""
    repr_str = str(loaded_gsm8k_parser)

    for fragment in ("GSM8KDatasetParser", "openai/gsm8k", "main", "loaded"):
        assert fragment in repr_str
166
+
167
+
168
@pytest.mark.integration
def test_different_splits_parsing(gsm8k_parser):
    """The test and train splits both parse and yield different entry counts."""
    entry_counts = {}

    # Load and parse each split in turn: test first, then train.
    for split in ("test", "train"):
        gsm8k_parser.load(task_name="main", split=split)
        gsm8k_parser.parse(split_names=split, force=True)
        entry_counts[split] = len(gsm8k_parser.get_parsed_data)

    assert entry_counts["test"] > 0
    assert entry_counts["train"] > 0
    assert entry_counts["train"] != entry_counts["test"]