File size: 4,372 Bytes
492106d
 
0157dfd
 
 
 
 
 
af461f3
492106d
 
51a31d4
51dabd6
0157dfd
af461f3
ab5dfc2
0157dfd
b06298d
af461f3
0157dfd
 
 
 
 
325e3c6
b06298d
ab5dfc2
f2e3e47
492106d
 
 
0157dfd
 
 
 
492106d
 
0157dfd
51dabd6
 
51a31d4
e9df5ab
 
51a31d4
 
b06298d
0157dfd
b06298d
 
0157dfd
 
 
 
 
492106d
 
0157dfd
 
 
 
 
492106d
 
0157dfd
 
 
492106d
 
0157dfd
 
 
492106d
 
0157dfd
b06298d
 
0157dfd
 
b06298d
 
 
 
492106d
 
0157dfd
 
 
 
 
 
 
 
 
b06298d
 
492106d
 
b06298d
492106d
 
 
 
 
 
 
b06298d
0157dfd
b06298d
492106d
 
 
 
 
 
0157dfd
492106d
 
 
 
 
 
 
 
0157dfd
492106d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from collections import namedtuple
from pprint import pprint
from dotenv import load_dotenv
# Needs to happen as the very first thing, otherwise HF ignores the env vars.
load_dotenv()

import os
import pandas as pd

from dataclasses import dataclass, field
from typing import Dict, cast, List
from datasets import DatasetDict, load_dataset

from src.readers.base_reader import Reader
from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.readers.longformer_reader import LongformerReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
    FaissRetriever,
    FaissRetrieverOptions
)
from src.utils.log import logger
from src.utils.preprocessing import context_to_reader_input
from src.utils.timing import get_times, timeit


# One evaluated question: ``correct`` holds the reference answer taken from the
# dataset, ``given`` holds the answer text produced by the reader.
ExperimentResult = namedtuple("ExperimentResult", "correct given")


@dataclass
class Experiment:
    """One retriever/reader combination together with the answers it produced.

    The script below calls ``retriever.retrieve`` and ``reader.read`` for each
    question and appends an ``ExperimentResult`` to ``results``.
    """
    # Object whose ``retrieve`` method is called with the question text.
    retriever: Retriever
    # Object whose ``read`` method is called with the question and the
    # retrieved context.
    reader: Reader
    # Tag of the reader model ("dpr" or "longformer" in this file); the script
    # uses it to decide how to pull the answer text out of the reader output.
    lm: str
    # Collected (correct, given) pairs; default_factory gives every instance
    # its own list instead of a shared one.
    results: List[ExperimentResult] = field(default_factory=list)


if __name__ == '__main__':
    # Evaluation script: run every retriever/reader combination over the test
    # questions, optionally store the recorded timings, and write the
    # exact-match and F1 scores of each experiment to CSV files.
    dataset_name = "GroNLP/ik-nlp-22_slp"
    # Number of passages requested from the retriever for each question.
    n_passages = 5

    paragraphs = cast(DatasetDict, load_dataset(dataset_name, "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    # The complete test split is evaluated. Lower subset_idx to run only the
    # first few questions when a quick run is needed.
    subset_idx = len(questions["test"])
    questions_test = questions["test"][:subset_idx]

    experiments: Dict[str, Experiment] = {
        "faiss_dpr": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
            reader=DprReader(),
            lm="dpr"
        ),
        "faiss_longformer": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
            reader=LongformerReader(),
            lm="longformer"
        ),
        "es_dpr": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=DprReader(),
            lm="dpr"
        ),
        "es_longformer": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=LongformerReader(),
            lm="longformer"
        ),
    }

    for experiment_name, experiment in experiments.items():
        logger.info(f"Running experiment {experiment_name}...")

        # Apply the timing decorator by hand so the recorded name can contain
        # the experiment name. The wrappers do not depend on the question, so
        # they are built once per experiment rather than once per question.
        # NOTE(review): assumes timeit(name)(func) has no per-wrap side
        # effects -- confirm against src.utils.timing.
        t_retrieve = timeit(f"{experiment_name}.retrieve")(
            experiment.retriever.retrieve)
        t_read = timeit(f"{experiment_name}.read")(experiment.reader.read)

        qa_pairs = zip(questions_test["question"], questions_test["answer"])
        for number, (question, answer) in enumerate(qa_pairs, start=1):
            # "\x1b[1K" erases the line up to the cursor and "\r" returns to
            # the first column, so the progress line is rewritten in place.
            print(f"\x1b[1K\r[{number:03}] - \"{question}\"", end='')

            # The retriever also returns scores; only the context is used.
            _, context = t_retrieve(question, n_passages)
            reader_input = context_to_reader_input(context)

            # Requesting a single answer results in us getting the best one.
            given_answer = t_read(question, reader_input, 1)[0]

            # Save the results so we can evaluate them later. The two readers
            # return differently shaped answers: the longformer answer is
            # indexed, the DPR answer exposes a ``text`` attribute.
            if experiment.lm == "longformer":
                given_text = given_answer[0]
            else:
                given_text = given_answer.text
            experiment.results.append(ExperimentResult(answer, given_text))

        # Terminate the in-place progress line.
        print()

    # Every output file below lives in this directory.
    os.makedirs("./results/", exist_ok=True)

    if os.getenv("ENABLE_TIMING", "false").lower() == "true":
        # Save times
        pd.DataFrame(get_times()).to_csv("./results/timings.csv")

    f1_results = pd.DataFrame(columns=list(experiments))
    em_results = pd.DataFrame(columns=list(experiments))
    for experiment_name, experiment in experiments.items():
        # evaluate() is unpacked as an (exact match, F1) pair, the same order
        # the script has always used. Building the columns with comprehensions
        # keeps an experiment without results from raising on unpacking.
        scores = [evaluate(r.correct, r.given) for r in experiment.results]
        em_results[experiment_name] = [em for em, _ in scores]
        f1_results[experiment_name] = [f1 for _, f1 in scores]

    f1_results.to_csv("./results/f1_scores.csv")
    em_results.to_csv("./results/em_scores.csv")