File size: 10,251 Bytes
44529bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0450c4e
44529bb
 
 
 
 
 
 
0450c4e
44529bb
 
 
 
 
 
 
 
 
 
 
 
0450c4e
44529bb
 
 
 
 
 
 
 
 
 
 
 
 
 
0450c4e
 
 
 
44529bb
 
18bf871
44529bb
 
 
 
 
 
0450c4e
44529bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0450c4e
44529bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18bf871
44529bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793be05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import pytest

from llmdataparser.mmlu_parser import (
    BaseMMLUDatasetParser,
    MMLUParseEntry,
    MMLUProDatasetParser,
    MMLUProParseEntry,
    MMLUReduxDatasetParser,
    TMMLUPlusDatasetParser,
)


@pytest.fixture
def base_parser():
    """Provide a newly constructed ``BaseMMLUDatasetParser``."""
    parser = BaseMMLUDatasetParser()
    return parser


@pytest.fixture
def redux_parser():
    """Provide a newly constructed ``MMLUReduxDatasetParser``."""
    parser = MMLUReduxDatasetParser()
    return parser


@pytest.fixture
def tmmlu_parser():
    """Provide a newly constructed ``TMMLUPlusDatasetParser``."""
    parser = TMMLUPlusDatasetParser()
    return parser


@pytest.fixture
def mmlu_pro_parser():
    """Provide a newly constructed ``MMLUProDatasetParser``."""
    parser = MMLUProDatasetParser()
    return parser


@pytest.fixture
def sample_mmlu_entries():
    """Provide two hand-written rows shaped like raw MMLU records.

    Each row carries a question, a list of four choices, the zero-based
    index of the correct choice under ``answer``, and a subject label.
    """
    capital_row = {
        "question": "What is the capital of France?",
        "choices": ["London", "Paris", "Berlin", "Madrid"],
        "answer": 1,  # position of "Paris" in the choices
        "subject": "geography",
    }
    color_row = {
        "question": "Which of these is a primary color?",
        "choices": ["Green", "Purple", "Blue", "Orange"],
        "answer": 2,  # position of "Blue" in the choices
        "subject": "art",
    }
    return [capital_row, color_row]


@pytest.fixture
def sample_mmlu_pro_entries():
    """Provide a single hand-written row shaped like a raw MMLU Pro record.

    The row has six options, a free-text answer, the zero-based index of the
    correct option under ``answer_index``, and a category label.
    """
    quicksort_row = {
        "question": "What is the time complexity of quicksort?",
        "options": ["O(n)", "O(n log n)", "O(n²)", "O(2ⁿ)", "O(n!)", "O(1)"],
        "answer": "The average time complexity of quicksort is O(n log n)",
        "answer_index": 1,
        "category": "computer_science",
    }
    return [quicksort_row]


def test_mmlu_parse_entry_creation_valid():
    """Check that ``MMLUParseEntry.create`` builds an entry from valid fields."""
    fields = {
        "question": "Test question",
        "answer": "A",
        "raw_question": "Test question",
        "raw_choices": ["choice1", "choice2", "choice3", "choice4"],
        "raw_answer": "0",
        "task_name": "test_task",
    }

    created = MMLUParseEntry.create(**fields)

    assert isinstance(created, MMLUParseEntry)
    # The values handed to ``create`` should be readable back unchanged.
    assert created.question == "Test question"
    assert created.answer == "A"
    assert created.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
    assert created.task_name == "test_task"


@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_mmlu_parse_entry_creation_invalid(invalid_answer):
    """Check that ``MMLUParseEntry.create`` rejects answers other than A-D.

    Each parametrized value is expected to raise ``ValueError`` with a
    message matching the pattern below.
    """
    expected_message = "Invalid answer_letter.*must be one of A, B, C, D"
    fields = {
        "question": "Test question",
        "answer": invalid_answer,
        "raw_question": "Test question",
        "raw_choices": ["choice1", "choice2", "choice3", "choice4"],
        "raw_answer": "4",
        "task_name": "test_task",
    }

    with pytest.raises(ValueError, match=expected_message):
        MMLUParseEntry.create(**fields)


def test_process_entry_base(base_parser, sample_mmlu_entries):
    """Check how the base MMLU parser converts the first sample row."""
    row = sample_mmlu_entries[0]
    parsed = base_parser.process_entry(row, task_name="geography")

    assert isinstance(parsed, MMLUParseEntry)
    # The sample row stores answer index 1, expected to come back as "B".
    assert parsed.answer == "B"

    # Every choice should appear in the question text with its letter label.
    for labelled_choice in ("A. London", "B. Paris", "C. Berlin", "D. Madrid"):
        assert labelled_choice in parsed.question

    # The raw fields keep the original row content; the answer is a string.
    assert parsed.raw_question == "What is the capital of France?"
    assert parsed.raw_choices == ["London", "Paris", "Berlin", "Madrid"]
    assert parsed.raw_answer == "1"
    assert parsed.task_name == "geography"


def test_mmlu_pro_parse_entry_creation_valid():
    """Check that ``MMLUProParseEntry.create`` builds an entry from valid fields.

    The answer letter "E" with five choices exercises a letter beyond the
    A-D range. The test also checks that the question, choices and task name
    are readable back from the created entry, matching the checks made for
    ``MMLUParseEntry`` creation.
    """
    choices = ["choice1", "choice2", "choice3", "choice4", "choice5"]
    entry = MMLUProParseEntry.create(
        question="Test question",
        answer="E",  # MMLU Pro supports up to J
        raw_question="Test question",
        # Pass a copy so the comparison below uses an untouched list.
        raw_choices=list(choices),
        raw_answer="4",
        task_name="test_task",
    )
    assert isinstance(entry, MMLUProParseEntry)
    assert entry.answer == "E"
    assert len(entry.raw_choices) == 5

    # Round-trip checks. NOTE(review): these assume ``create`` stores the
    # fields unchanged, as the MMLUParseEntry test does - confirm.
    assert entry.question == "Test question"
    assert entry.raw_choices == choices
    assert entry.task_name == "test_task"


def test_process_entry_mmlu_pro(mmlu_pro_parser, sample_mmlu_pro_entries):
    """Check how the MMLU Pro parser converts the sample row."""
    row = sample_mmlu_pro_entries[0]
    parsed = mmlu_pro_parser.process_entry(row, task_name="computer_science")

    assert isinstance(parsed, MMLUProParseEntry)
    # The sample row stores answer_index 1, expected to come back as "B".
    assert parsed.answer == "B"
    assert "O(n log n)" in parsed.question
    assert parsed.task_name == "computer_science"
    # All six options of the sample row should be kept.
    assert len(parsed.raw_choices) == 6


def test_tmmlu_process_entry(tmmlu_parser):
    """Check how the TMMLU+ parser converts a row with per-letter choice keys.

    Unlike the base MMLU sample rows, this row stores each choice under its
    own "A".."D" key and gives the answer as a letter.
    """
    task = "geography_of_taiwan"
    row = {
        "question": "什麼是台灣最高的山峰?",
        "A": "玉山",
        "B": "阿里山",
        "C": "合歡山",
        "D": "雪山",
        "answer": "A",
        "subject": "geography_of_taiwan",
    }

    parsed = tmmlu_parser.process_entry(row, task_name=task)

    assert isinstance(parsed, MMLUParseEntry)
    assert parsed.answer == "A"
    # The four per-letter values are expected as a list in A-D order.
    assert parsed.raw_choices == ["玉山", "阿里山", "合歡山", "雪山"]
    assert parsed.task_name == "geography_of_taiwan"


@pytest.mark.parametrize(
    ("parser_fixture", "expected_tasks", "expected_source"),
    [
        ("base_parser", 57, "cais/mmlu"),
        ("redux_parser", 30, "edinburgh-dawg/mmlu-redux"),
        ("tmmlu_parser", 66, "ikala/tmmluplus"),
        ("mmlu_pro_parser", 1, "TIGER-Lab/MMLU-Pro"),
    ],
)
def test_parser_initialization(
    request, parser_fixture, expected_tasks, expected_source
):
    """Check task count, data source and Hugging Face link for each parser.

    The parser is looked up by fixture name so one test body covers all four
    parser variants.
    """
    parser = request.getfixturevalue(parser_fixture)
    expected_link = f"https://huggingface.co/datasets/{expected_source}"

    assert len(parser.task_names) == expected_tasks
    assert parser._data_source == expected_source
    assert parser.get_huggingface_link == expected_link


@pytest.mark.integration
def test_load_dataset(base_parser):
    """Check parser state after loading one task and split of MMLU."""
    task, split = "anatomy", "test"

    base_parser.load(task_name=task, split=split)

    assert base_parser.raw_data is not None
    assert base_parser.split_names == [split]
    assert base_parser._current_task == task


def test_parser_string_representation(base_parser):
    """Check the text produced by ``str()`` on a parser with nothing loaded."""
    rendered = str(base_parser)

    # Expected fragments: parser class name, data source, and load status.
    for fragment in ("MMLUDatasetParser", "cais/mmlu", "not loaded"):
        assert fragment in rendered


@pytest.mark.integration
def test_different_splits_parsing(base_parser):
    """Check that the test and validation splits parse to different sizes."""
    parsed_counts = {}

    # Load and parse each split in turn on the same parser, recording how
    # many parsed entries it holds afterwards.
    for split in ("test", "validation"):
        base_parser.load(task_name="anatomy", split=split)
        base_parser.parse(split_names=split, force=True)
        parsed_counts[split] = len(base_parser.get_parsed_data)

    test_count = parsed_counts["test"]
    val_count = parsed_counts["validation"]

    assert test_count > 0
    assert val_count > 0
    assert test_count != val_count


def test_base_mmlu_dataset_description(base_parser):
    """Check the dataset description returned by the base MMLU parser."""
    info = base_parser.get_dataset_description()

    # Identity fields
    assert info.name == "Massive Multitask Language Understanding (MMLU)"
    assert "cais/mmlu" in info.source
    assert info.language == "English"

    # The characteristics text should mention the subject count
    # (compared case-insensitively).
    assert "57 subjects" in info.characteristics.lower()

    # The citation should contain the expected key.
    assert "hendryckstest2021" in info.citation


def test_mmlu_redux_dataset_description(redux_parser):
    """Check the dataset description returned by the MMLU Redux parser."""
    info = redux_parser.get_dataset_description()

    # Identity fields; the purpose text is compared case-insensitively.
    assert info.name == "MMLU Redux"
    assert "manually re-annotated" in info.purpose.lower()
    assert "edinburgh-dawg/mmlu-redux" in info.source
    assert info.language == "English"

    # The characteristics text should mention the "3,000" figure.
    assert "3,000" in info.characteristics


def test_tmmlu_plus_dataset_description(tmmlu_parser):
    """Check the dataset description returned by the TMMLU+ parser."""
    info = tmmlu_parser.get_dataset_description()

    # Identity fields
    assert "ikala/tmmluplus" in info.source
    assert info.language == "Traditional Chinese"

    # The characteristics text should mention the subject count
    # (compared case-insensitively).
    assert "66 subjects" in info.characteristics.lower()

    # The citation should contain the expected key.
    assert "ikala2024improved" in info.citation


def test_mmlu_pro_dataset_description(mmlu_pro_parser):
    """Check the dataset description returned by the MMLU Pro parser."""
    info = mmlu_pro_parser.get_dataset_description()

    # Identity fields; the purpose text is compared case-insensitively.
    assert info.name == "MMLU Pro"
    assert "challenging" in info.purpose.lower()
    assert "TIGER-Lab/MMLU-Pro" in info.source
    assert info.language == "English"


def test_base_mmlu_evaluation_metrics(base_parser):
    """Check the evaluation metrics advertised by the base MMLU parser."""
    metrics = base_parser.get_evaluation_metrics()

    assert len(metrics) >= 3
    available = {metric.name for metric in metrics}

    # These three metric names must all be present.
    for required in ("accuracy", "subject_accuracy", "category_accuracy"):
        assert required in available

    # Inspect the first metric named "accuracy" in more detail.
    accuracy_metric = next(
        filter(lambda metric: metric.name == "accuracy", metrics)
    )
    assert accuracy_metric.type == "classification"
    assert accuracy_metric.primary is True
    assert "multiple-choice" in accuracy_metric.description.lower()


def test_mmlu_redux_evaluation_metrics(redux_parser):
    """Check that MMLU Redux advertises the ``question_clarity`` metric."""
    available = {
        metric.name for metric in redux_parser.get_evaluation_metrics()
    }

    assert "question_clarity" in available


def test_tmmlu_plus_evaluation_metrics(tmmlu_parser):
    """Check that TMMLU+ advertises the ``difficulty_analysis`` metric."""
    available = {
        metric.name for metric in tmmlu_parser.get_evaluation_metrics()
    }

    assert "difficulty_analysis" in available


def test_mmlu_pro_evaluation_metrics(mmlu_pro_parser):
    """Check that MMLU Pro advertises its two extra analysis metrics."""
    available = {
        metric.name for metric in mmlu_pro_parser.get_evaluation_metrics()
    }

    for required in ("reasoning_analysis", "prompt_robustness"):
        assert required in available