from dotenv import load_dotenv
# Must run before anything imports Hugging Face libraries, otherwise HF
# ignores the env vars set in .env
load_dotenv()

import os
import pandas as pd
import torch

from dataclasses import dataclass
from typing import Dict, cast
from datasets import DatasetDict, load_dataset

from src.readers.base_reader import Reader
from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.readers.longformer_reader import LongformerReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
    FaissRetriever,
    FaissRetrieverOptions
)
from src.utils.log import logger
from src.utils.preprocessing import context_to_reader_input
from src.utils.timing import get_times, timeit


@dataclass
class Experiment:
    """A retriever/reader pairing to run over the test questions."""
    retriever: Retriever
    reader: Reader


if __name__ == '__main__':
    dataset_name = "GroNLP/ik-nlp-22_slp"
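    # The dataset ships two configurations: "paragraphs" (the retrieval
    # corpus) and "questions" (QA pairs, of which we use the test split).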
    paragraphs = cast(DatasetDict, load_dataset(dataset_name, "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    # Run over the full test split; lower subset_idx to a small number to
    # iterate more quickly during development.
    subset_idx = len(questions["test"])
    questions_test = questions["test"][:subset_idx]

    experiments: Dict[str, Experiment] = {
        "faiss_dpr": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
            reader=DprReader()
        ),
        "faiss_longformer": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
            reader=LongformerReader()
        ),
        "es_dpr": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=DprReader()
        ),
        "es_longformer": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=LongformerReader()
        ),
    }
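    # The four experiments form a 2x2 grid over {FAISS, Elasticsearch}
    # retrieval and {DPR, Longformer} reading.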

    for experiment_name, experiment in experiments.items():
        logger.info(f"Running experiment {experiment_name}...")
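        # Running totals for the evaluation sketch further below
        total_exact, total_f1 = 0.0, 0.0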
        for idx in range(subset_idx):
            question = questions_test["question"][idx]
            answer = questions_test["answer"][idx]

            # Workaround: apply the timeit decorator by hand so the recorded
            # timings carry a dynamic, per-experiment name.
            retrieve_timer = timeit(f"{experiment_name}.retrieve")
            t_retrieve = retrieve_timer(experiment.retriever.retrieve)

            read_timer = timeit(f"{experiment_name}.read")
            t_read = read_timer(experiment.reader.read)

            # "\x1b[1K\r" erases the current terminal line, so this status
            # line updates in place.
            print(f"\x1b[1K\r[{idx+1:03}] - \"{question}\"", end='')

            scores, context = t_retrieve(question, 5)
            reader_input = context_to_reader_input(context)

            answers = t_read(question, reader_input, 5)

            # print_answers(answers, scores, context)
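
            # Evaluation sketch, adapted from the commented-out code at the
            # bottom of this file. It assumes, as that code does, that each
            # prediction exposes .relevance_score, .span_score and .text,
            # and that evaluate() returns an (exact_match, f1) pair.
            sm = torch.nn.Softmax(dim=0)
            document_scores = sm(torch.Tensor(
                [pred.relevance_score for pred in answers]))
            span_scores = sm(torch.Tensor(
                [pred.span_score for pred in answers]))
            best_index = int(torch.argmax(document_scores + span_scores))
            exact_match, f1_score = evaluate(answer, answers[best_index].text)
            total_exact += exact_match
            total_f1 += f1_score

            # TODO: store the per-question results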
        print()
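        # Report aggregate scores for the evaluation sketch above
        logger.info(
            f"{experiment_name}: exact match {total_exact / subset_idx:.02f}, "
            f"F1 {total_f1 / subset_idx:.02f}")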

    # Persist the per-call timings recorded by the timeit decorator
    times = get_times()

    df = pd.DataFrame(times)
    os.makedirs("./results/", exist_ok=True)
    df.to_csv("./results/timings.csv")


    # --- Older exploratory code, kept commented out for reference ---

    # # Initialize retriever
    # retriever = FaissRetriever(paragraphs)
    # # retriever = ESRetriever(paragraphs)

    # # Retrieve example
    # # random.seed(111)
    # random_index = random.randint(0, len(questions_test["question"])-1)
    # example_q = questions_test["question"][random_index]
    # example_a = questions_test["answer"][random_index]

    # scores, result = retriever.retrieve(example_q)
    # reader_input = context_to_reader_input(result)

    # # TODO: use new code from query.py to clean this up
    # # Initialize reader
    # answers = reader.read(example_q, reader_input)

    # # Calculate softmaxed scores for readable output
    # sm = torch.nn.Softmax(dim=0)
    # document_scores = sm(torch.Tensor(
    #     [pred.relevance_score for pred in answers]))
    # span_scores = sm(torch.Tensor(
    #     [pred.span_score for pred in answers]))

    # print(example_q)
    # for answer_i, answer in enumerate(answers):
    #     print(f"[{answer_i + 1}]: {answer.text}")
    #     print(f"\tDocument {answer.doc_id}", end='')
    #     print(f"\t(score {document_scores[answer_i] * 100:.02f})")
    #     print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
    #     print(f"\t(score {span_scores[answer_i] * 100:.02f})")
    #     print()  # Newline

    # # print(f"Example q: {example_q} answer: {result['text'][0]}")

    # # for i, score in enumerate(scores):
    # #     print(f"Result {i+1} (score: {score:.02f}):")
    # #     print(result['text'][i])

    # # Determine best answer we want to evaluate
    # highest, highest_index = 0, 0
    # for i, value in enumerate(span_scores):
    #     if value + document_scores[i] > highest:
    #         highest = value + document_scores[i]
    #         highest_index = i

    # # Retrieve exact match and F1-score
    # exact_match, f1_score = evaluate(
    #     example_a, answers[highest_index].text)
    # print(f"Gold answer: {example_a}\n"
    #       f"Predicted answer: {answers[highest_index].text}\n"
    #       f"Exact match: {exact_match:.02f}\n"
    #       f"F1-score: {f1_score:.02f}")

    # Calculate overall performance
    # total_f1 = 0
    # total_exact = 0
    # total_len = len(questions_test["question"])
    # start_time = time.time()
    # for i, question in enumerate(questions_test["question"]):
    #     print(question)
    #     answer = questions_test["answer"][i]
    #     print(answer)
    #
    #     scores, result = retriever.retrieve(question)
    #     reader_input = context_to_reader_input(result)
    #     answers = reader.read(question, reader_input)
    #
    #     document_scores = sm(torch.Tensor(
    #         [pred.relevance_score for pred in answers]))
    #     span_scores = sm(torch.Tensor(
    #         [pred.span_score for pred in answers]))
    #
    #     highest, highest_index = 0, 0
    #     for j, value in enumerate(span_scores):
    #         if value + document_scores[j] > highest:
    #             highest = value + document_scores[j]
    #             highest_index = j
    #     print(answers[highest_index])
    #     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
    #     total_f1 += f1_score
    #     total_exact += exact_match
    # print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
    # print(total_f1)
    # print(total_exact)
    # print(total_f1/total_len)