from dataclasses import dataclass
from typing import Any, Final

from llmdataparser.base_parser import (
    DatasetDescription,
    EvaluationMetric,
    HuggingFaceDatasetParser,
    HuggingFaceParseEntry,
)

TW_LEGAL_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
TW_LEGAL_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(TW_LEGAL_VALID_ANSWERS))


@dataclass(frozen=True, kw_only=True, slots=True)
class TWLegalParseEntry(HuggingFaceParseEntry):
    """Custom entry class for Taiwan Legal Benchmark, with fields specific to this dataset parser."""

    raw_choices: list[str]

    @classmethod
    def create(
        cls,
        question: str,
        answer: str,
        raw_question: str,
        raw_choices: list[str],
        raw_answer: str,
        task_name: str,
    ) -> "TWLegalParseEntry":
        if answer not in TW_LEGAL_VALID_ANSWERS:
            raise ValueError(
                f"Invalid answer_letter '{answer}'; must be one of {TW_LEGAL_VALID_ANSWER_STR}"
            )
        return cls(
            question=question,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            raw_choices=raw_choices,
            task_name=task_name,
        )


class TWLegalDatasetParser(HuggingFaceDatasetParser[TWLegalParseEntry]):
    """Parser for the Taiwan Legal Benchmark dataset."""

    _data_source = "lianghsun/tw-legal-benchmark-v1"
    _default_task = "default"
    _task_names = ["default"]

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> TWLegalParseEntry:
        """Process a single Taiwan Legal Benchmark entry."""
        task = task_name or self._get_current_task(row)

        # Extract the choices in order and label them A-D (chr(65) == "A")
        raw_choices = [row["A"], row["B"], row["C"], row["D"]]
        choices = "\n".join(
            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
        )
        raw_question = row["question"]
        raw_answer = row["answer"]

        question = f"Question: {raw_question}\n{choices}\nAnswer:"

        return TWLegalParseEntry.create(
            question=question,
            answer=raw_answer,
            raw_question=raw_question,
            raw_choices=raw_choices,
            raw_answer=raw_answer,
            task_name=task,
        )
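
    # For illustration (schema inferred from the keys accessed above), a row
    # such as
    #   {"question": "...", "A": "...", "B": "...", "C": "...", "D": "...",
    #    "answer": "C"}
    # yields an entry whose `question` is the formatted prompt
    #   "Question: ...\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer:"
    # and whose `answer` is the letter "C".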

    def get_dataset_description(self) -> DatasetDescription:
        """Returns description of the Taiwan Legal Benchmark dataset."""
        return DatasetDescription.create(
            name="Taiwan Legal Benchmark",
            language="Traditional Chinese",
            purpose="Evaluate models on Taiwan-specific legal knowledge and understanding",
            source="Taiwan Bar Examination questions",
            category=["Taiwan", "General Knowledge and Reasoning", "Legal"],
            format="Multiple choice questions (A/B/C/D)",
            characteristics=(
                "Contains questions from Taiwan's bar examination, testing understanding "
                "of Taiwan's legal system, terminology, and concepts"
            ),
            citation="""
                url={https://huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1}
            """,
        )

    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
        """Returns recommended evaluation metrics for Taiwan Legal Benchmark."""
        return [
            EvaluationMetric.create(
                name="accuracy",
                type="classification",
                description="Overall percentage of correctly answered legal questions",
                implementation="datasets.load_metric('accuracy')",
                primary=True,
            ),
        ]
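
    # A minimal sketch of computing the recommended metric with the `evaluate`
    # library (an assumption: `evaluate` is not a dependency of this module).
    # Letter labels are mapped to integer ids because the `accuracy` metric
    # expects numeric labels:
    #
    #   import evaluate
    #
    #   label_to_id = {"A": 0, "B": 1, "C": 2, "D": 3}
    #   accuracy = evaluate.load("accuracy")
    #   result = accuracy.compute(
    #       predictions=[label_to_id[p] for p in predicted_letters],
    #       references=[label_to_id[e.answer] for e in entries],
    #   )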


if __name__ == "__main__":
    # Example usage
    parser = TWLegalDatasetParser()
    parser.load()
    parser.parse()

    # Get parsed data with correct type
    parsed_data = parser.get_parsed_data

    # Print example entry
    if parsed_data:
        example = parsed_data[0]
        print("\nExample parsed entry:")
        print(f"Question: {example.question}")
        print("Choices:")
        for i, choice in enumerate(example.raw_choices):
            print(f"{chr(65 + i)}. {choice}")
        print(f"Correct Answer: {example.answer}")
        print(f"Task Name: {example.task_name}")