# LLMEval-Dataset-Parser / tests/test_mgsm_parser.py
# Commit 0450c4e (unverified) by JeffYang52415: "refactor: remove system prompt"
import pytest
from llmdataparser.mgsm_parser import MGSMDatasetParser, MGSMParseEntry
@pytest.fixture
def mgsm_parser():
    """Provide a newly constructed MGSM parser; no ``load()`` call is made."""
    parser = MGSMDatasetParser()
    return parser
@pytest.fixture
def loaded_mgsm_parser(mgsm_parser):
    """Provide the ``mgsm_parser`` fixture after loading the English test split.

    ``load`` is invoked on the parser's real data source, so tests using this
    fixture presumably need dataset access (network/cache) — TODO confirm.
    """
    mgsm_parser.load(split="test", task_name="en")
    return mgsm_parser
@pytest.fixture
def sample_mgsm_entries():
    """Provide three hand-written MGSM rows (English, Spanish, Japanese).

    All rows describe the same problem (5 + 3 = 8). The Japanese row has its
    ``answer`` set to ``None`` to exercise the missing-detailed-answer case.
    """
    # (question, detailed answer, language code)
    rows = [
        (
            "John has 5 apples and buys 3 more. How many apples does he have now?",
            "Let's solve step by step:\n1) Initial apples = 5\n2) Bought apples = 3\n3) Total = 5 + 3 = 8\nJohn has 8 apples now.",
            "en",
        ),
        (
            "Juan tiene 5 manzanas y compra 3 más. ¿Cuántas manzanas tiene ahora?",
            "Resolvamos paso a paso:\n1) Manzanas iniciales = 5\n2) Manzanas compradas = 3\n3) Total = 5 + 3 = 8\nJuan tiene 8 manzanas ahora.",
            "es",
        ),
        (
            "ジョンはリンゴを5個持っていて、さらに3個買います。今何個持っていますか?",
            None,  # no detailed answer for this row
            "ja",
        ),
    ]
    return [
        {
            "question": question,
            "answer": answer,
            "answer_number": 8,
            "equation_solution": "5 + 3 = 8",
            "language": language,
        }
        for question, answer, language in rows
    ]
def test_mgsm_parse_entry_creation_valid():
    """Test valid creation of MGSMParseEntry with all fields.

    The processed fields (``question``/``answer``) and the raw fields
    (``raw_question``/``raw_answer``) are given distinct values. If they were
    identical, a mix-up between them inside ``MGSMParseEntry.create`` would
    go unnoticed.
    """
    entry = MGSMParseEntry.create(
        question="Test question",
        answer="Test answer",
        raw_question="Raw test question",
        raw_answer="Raw test answer",
        numerical_answer=42,
        equation_solution="21 * 2 = 42",
        task_name="en",
        language="en",
    )
    assert isinstance(entry, MGSMParseEntry)
    assert entry.question == "Test question"
    assert entry.answer == "Test answer"
    assert entry.raw_question == "Raw test question"
    assert entry.raw_answer == "Raw test answer"
    assert entry.numerical_answer == 42
    assert entry.equation_solution == "21 * 2 = 42"
    assert entry.task_name == "en"
    assert entry.language == "en"
def test_process_entry_with_detailed_answer(mgsm_parser, sample_mgsm_entries):
    """An English row with a worked solution keeps that solution in ``answer``."""
    english_row = sample_mgsm_entries[0]
    parsed = mgsm_parser.process_entry(english_row, task_name="en")

    assert isinstance(parsed, MGSMParseEntry)
    assert parsed.numerical_answer == 8
    assert parsed.equation_solution == "5 + 3 = 8"
    # The detailed, step-by-step text is carried through.
    assert "step by step" in parsed.answer
    assert parsed.language == "en"
    assert parsed.task_name == "en"
def test_process_entry_without_detailed_answer(mgsm_parser, sample_mgsm_entries):
    """A Japanese row whose ``answer`` is ``None`` falls back to the numeric answer."""
    japanese_row = sample_mgsm_entries[2]
    parsed = mgsm_parser.process_entry(japanese_row, task_name="ja")

    assert isinstance(parsed, MGSMParseEntry)
    assert parsed.numerical_answer == 8
    assert parsed.equation_solution == "5 + 3 = 8"
    # With no worked solution, the answer is the stringified answer_number.
    assert parsed.answer == "8"
    assert parsed.language == "ja"
    assert parsed.task_name == "ja"
def test_process_entry_spanish(mgsm_parser, sample_mgsm_entries):
    """A Spanish row keeps its Spanish worked solution and language tags."""
    spanish_row = sample_mgsm_entries[1]
    parsed = mgsm_parser.process_entry(spanish_row, task_name="es")

    assert isinstance(parsed, MGSMParseEntry)
    assert parsed.numerical_answer == 8
    assert parsed.equation_solution == "5 + 3 = 8"
    # "paso a paso" is Spanish for "step by step".
    assert "paso a paso" in parsed.answer
    assert parsed.language == "es"
    assert parsed.task_name == "es"
def test_mgsm_parser_initialization(mgsm_parser):
    """Check the configuration exposed by a freshly constructed parser."""
    names = mgsm_parser.task_names
    assert isinstance(names, list)
    assert len(names) == 11  # 11 supported languages

    assert mgsm_parser._data_source == "juletxara/mgsm"
    assert mgsm_parser._default_task == "en"

    missing = [code for code in ("en", "es", "ja", "zh") if code not in names]
    assert not missing

    expected_link = "https://huggingface.co/datasets/juletxara/mgsm"
    assert mgsm_parser.get_huggingface_link == expected_link
@pytest.mark.integration
def test_load_dataset(loaded_mgsm_parser):
    """After loading, raw data is present, only "test" is exposed, and the task is English."""
    parser = loaded_mgsm_parser
    assert parser.raw_data is not None
    assert parser.split_names == ["test"]
    assert parser._current_task == "en"
@pytest.mark.integration
def test_parser_string_representation(loaded_mgsm_parser):
    """Test the string representation of a loaded MGSM parser.

    This test carries the ``integration`` marker because the
    ``loaded_mgsm_parser`` fixture calls ``load()`` on the real
    ``juletxara/mgsm`` data source, just as ``test_load_dataset`` does. Without
    the marker, a run that deselects integration tests would still hit the
    dataset here.
    """
    repr_str = str(loaded_mgsm_parser)
    assert "MGSMDatasetParser" in repr_str
    assert "juletxara/mgsm" in repr_str
    assert "en" in repr_str  # the fixture loads task_name="en"
    assert "loaded" in repr_str
@pytest.mark.integration
def test_different_languages_parsing(mgsm_parser):
    """English and Spanish test splits each parse to the same non-zero row count."""
    counts = {}
    for code in ("en", "es"):
        # Reload on the same parser instance, then force a re-parse.
        mgsm_parser.load(task_name=code, split="test")
        mgsm_parser.parse(split_names="test", force=True)
        counts[code] = len(mgsm_parser.get_parsed_data)

    assert counts["en"] > 0
    assert counts["es"] > 0
    # Each language should contain the same set of translated problems.
    assert counts["en"] == counts["es"]
@pytest.mark.parametrize("language", ["en", "es", "ja", "zh", "ru"])
def test_supported_languages(mgsm_parser, language):
    """A minimal row is processed correctly for each parametrized language code."""
    row = {
        "question": f"Test question in {language}",
        "answer": f"Test answer in {language}",
        "answer_number": 42,
        "equation_solution": "21 * 2 = 42",
    }
    parsed = mgsm_parser.process_entry(row, task_name=language)

    # Both the language tag and the task name come from the task_name argument.
    assert parsed.language == language
    assert parsed.task_name == language
    assert parsed.numerical_answer == 42
def test_get_dataset_description(mgsm_parser):
    """Validate the metadata returned by ``get_dataset_description``."""
    desc = mgsm_parser.get_dataset_description()

    assert desc.name == "Multilingual Grade School Math (MGSM)"
    assert "multilingual chain-of-thought reasoning" in desc.purpose.lower()
    assert "juletxara/mgsm" in desc.source
    assert desc.language == "Multilingual (11 languages)"
    assert "mathematical reasoning" in desc.characteristics.lower()

    # Both expected BibTeX keys must appear in the citation text.
    for bib_key in ("shi2022language", "cobbe2021gsm8k"):
        assert bib_key in desc.citation

    info = desc.additional_info
    assert info is not None
    languages = info["languages"]
    assert len(languages) == 11
    assert "English" in languages
    assert "Chinese" in languages
def test_get_evaluation_metrics(mgsm_parser):
    """Validate the metric definitions returned by ``get_evaluation_metrics``."""
    metrics = mgsm_parser.get_evaluation_metrics()

    # Four metrics overall, three of which are flagged as primary.
    assert len(metrics) == 4
    assert sum(1 for metric in metrics if metric.primary) == 3

    by_name = {metric.name: metric for metric in metrics}
    for expected_name in (
        "exact_match",
        "solution_validity",
        "step_accuracy",
        "cross_lingual_consistency",
    ):
        assert expected_name in by_name

    # (name, type, phrase in description, fragment of implementation)
    expectations = [
        ("exact_match", "string", "numerical answers", "custom_exact_match"),
        (
            "solution_validity",
            "text",
            "mathematically valid",
            "custom_solution_validator",
        ),
        ("step_accuracy", "numerical", "calculation steps", "custom_step_accuracy"),
    ]
    for name, metric_type, phrase, impl_fragment in expectations:
        metric = by_name[name]
        assert metric.type == metric_type
        assert metric.primary is True
        assert phrase in metric.description.lower()
        assert impl_fragment in metric.implementation