File size: 6,404 Bytes
6ed7950 0450c4e 6ed7950 0450c4e 6ed7950 0450c4e 6ed7950 58d5612 6ed7950 58d5612 6ed7950 58d5612 6ed7950 2e6d41b 6ed7950 58d5612 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import pytest
from llmdataparser.tmlu_parser import TMLUDatasetParser, TMLUParseEntry
@pytest.fixture
def tmlu_parser():
    """Provide a newly constructed TMLU parser (no dataset loaded yet)."""
    parser = TMLUDatasetParser()
    return parser
@pytest.fixture
def sample_tmlu_entries():
    """Return two raw TMLU-style records for parser tests.

    The first record is an AST Chinese item whose answer is ``B``; the second
    is an AST mathematics item whose answer is ``C``. Both carry the same
    timestamp and an empty ``explanation_source`` in their metadata.
    """
    shared_timestamp = "2023-10-09T18:27:20.304623"

    def make_row(question, choices, answer, explanation, source):
        # Keys are inserted in a fixed order: question, A-D, answer,
        # explanation, metadata.
        row = {"question": question}
        row.update(zip("ABCD", choices))
        row["answer"] = answer
        row["explanation"] = explanation
        row["metadata"] = {
            "timestamp": shared_timestamp,
            "source": source,
            "explanation_source": "",
        }
        return row

    return [
        make_row(
            "閱讀下文,選出依序最適合填入□內的選項:",
            [
                "張揚/綢繆未雨/奏疏",
                "抽搐/煮繭抽絲/奏疏",
                "張揚/煮繭抽絲/進貢",
                "抽搐/綢繆未雨/進貢",
            ],
            "B",
            "根據文意,選項B最為恰當。",
            "AST chinese - 108",
        ),
        make_row(
            "下列何者是質數?",
            ["21", "27", "31", "33"],
            "C",
            "31是質數,其他選項都是合數。",
            "AST mathematics - 108",
        ),
    ]
def test_tmlu_parse_entry_creation_valid():
    """Test valid creation of TMLUParseEntry.

    The values given to ``create`` should be readable back unchanged from the
    resulting entry, including ``task_name``.
    """
    entry = TMLUParseEntry.create(
        question="Test question",
        answer="A",
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4"],
        raw_answer="A",
        task_name="AST_chinese",
        explanation="Test explanation",
        metadata={"source": "test"},
    )
    assert isinstance(entry, TMLUParseEntry)
    assert entry.question == "Test question"
    assert entry.answer == "A"
    assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
    # task_name was passed in above but previously never verified.
    assert entry.task_name == "AST_chinese"
    assert entry.explanation == "Test explanation"
    assert entry.metadata == {"source": "test"}
@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_tmlu_parse_entry_creation_invalid(invalid_answer):
    """Creating an entry with an answer outside A-D must raise ValueError."""
    kwargs = dict(
        question="Test question",
        answer=invalid_answer,
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4"],
        raw_answer=invalid_answer,
        task_name="AST_chinese",
    )
    expected_message = "Invalid answer_letter.*must be one of A, B, C, D"
    with pytest.raises(ValueError, match=expected_message):
        TMLUParseEntry.create(**kwargs)
def test_process_entry(tmlu_parser, sample_tmlu_entries):
    """Test processing a raw TMLU record into a TMLUParseEntry."""
    raw_row = sample_tmlu_entries[0]
    entry = tmlu_parser.process_entry(raw_row, task_name="AST_chinese")
    assert isinstance(entry, TMLUParseEntry)
    assert entry.answer == "B"
    assert entry.task_name == "AST_chinese"
    assert len(entry.raw_choices) == 4
    assert entry.explanation == "根據文意,選項B最為恰當。"
    # Exact match rather than substring: the source string should be passed
    # through unchanged, so a containment check would hide decoration bugs.
    assert entry.metadata["source"] == "AST chinese - 108"
def test_tmlu_parser_initialization(tmlu_parser):
    """Check the parser's static configuration immediately after construction."""
    tasks = tmlu_parser.task_names
    assert isinstance(tasks, list)
    assert len(tasks) == 37  # Total number of tasks
    assert tmlu_parser._data_source == "miulab/tmlu"
    assert tmlu_parser._default_task == "AST_chinese"
    for expected_task in ("AST_chinese", "GSAT_mathematics"):
        assert expected_task in tasks
    expected_link = "https://huggingface.co/datasets/miulab/tmlu"
    assert tmlu_parser.get_huggingface_link == expected_link
@pytest.mark.integration
def test_load_dataset(tmlu_parser):
    """Load one task/split and confirm the parser records what was loaded.

    Marked ``integration``; presumably this fetches data from the Hugging Face
    Hub (TODO confirm).
    """
    task, split = "AST_chinese", "test"
    tmlu_parser.load(task_name=task, split=split)
    assert tmlu_parser.raw_data is not None
    assert tmlu_parser.split_names == [split]
    assert tmlu_parser._current_task == task
def test_parser_string_representation(tmlu_parser):
    """str() of an unloaded parser mentions its class, dataset and load state."""
    text = str(tmlu_parser)
    for fragment in ("TMLUDatasetParser", "miulab/tmlu", "not loaded"):
        assert fragment in text
@pytest.mark.integration
def test_different_tasks_parsing(tmlu_parser):
    """Test parsing different tasks of the dataset with one parser instance.

    After each ``load`` the active task is checked explicitly. Otherwise, if a
    reload ignored the new ``task_name``, the second parse would re-parse
    AST_chinese and the non-empty count checks would still pass.
    """
    counts = {}
    for task in ("AST_chinese", "AST_mathematics"):
        tmlu_parser.load(task_name=task, split="test")
        assert tmlu_parser._current_task == task
        tmlu_parser.parse(split_names="test", force=True)
        parsed = tmlu_parser.get_parsed_data
        assert all(isinstance(entry, TMLUParseEntry) for entry in parsed)
        counts[task] = len(parsed)
    assert counts["AST_chinese"] > 0
    assert counts["AST_mathematics"] > 0
def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
    """The raw record's metadata keys are present on the processed entry."""
    metadata = tmlu_parser.process_entry(sample_tmlu_entries[0]).metadata
    for key in ("timestamp", "source", "explanation_source"):
        assert key in metadata
    assert metadata["source"] == "AST chinese - 108"
def test_get_dataset_description(tmlu_parser):
    """Validate the fields of the dataset description object."""
    desc = tmlu_parser.get_dataset_description()

    # Fields expected to match exactly.
    assert desc.name == "Taiwan Multiple-choice Language Understanding (TMLU)"
    assert desc.language == "Traditional Chinese"
    assert desc.format == "Multiple choice questions (A/B/C/D)"

    # Free-text fields: only a key phrase is required to appear.
    assert "Taiwan-specific educational" in desc.purpose
    assert "Various Taiwan standardized tests" in desc.source
    assert "Advanced Subjects Test (AST)" in desc.characteristics
    assert "DBLP:journals/corr/abs-2403-20180" in desc.citation
def test_get_evaluation_metrics(tmlu_parser):
    """Exactly two metrics are reported, and both are flagged primary."""
    metrics = tmlu_parser.get_evaluation_metrics()
    assert len(metrics) == 2
    primary_names = [metric.name for metric in metrics if metric.primary]
    assert len(primary_names) == 2
    assert "accuracy" in primary_names
    assert "per_subject_accuracy" in primary_names
|