File size: 4,148 Bytes
af461f3
 
 
b7158e7
af461f3
 
 
51a31d4
f2e3e47
51dabd6
af461f3
ab5dfc2
af461f3
ab5dfc2
51a31d4
325e3c6
ab5dfc2
 
f2e3e47
 
ab5dfc2
51dabd6
 
51a31d4
e9df5ab
 
51a31d4
 
 
 
51dabd6
e9df5ab
 
ab5dfc2
 
1fb8ae3
ade7e41
 
 
ab5dfc2
ade7e41
325e3c6
ab5dfc2
e9df5ab
ab5dfc2
 
 
 
 
 
 
 
 
 
51dabd6
ab5dfc2
 
 
 
 
 
 
 
51a31d4
ab5dfc2
51dabd6
ab5dfc2
 
 
51dabd6
ade7e41
 
 
 
 
 
 
 
 
 
 
 
 
 
b7158e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import random
from typing import cast
import time

import torch
import transformers
from datasets import DatasetDict, load_dataset
from dotenv import load_dotenv

from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
from src.utils.preprocessing import context_to_reader_input

# Module-level logger from the project's logging helper (not referenced in the
# visible part of this script).
logger = get_logger()

# Load environment variables from a local .env file (python-dotenv), presumably
# so the retrievers/readers can pick up settings from the environment — TODO confirm.
load_dotenv()
# Only let the transformers library log errors, hiding its warning/info output.
transformers.logging.set_verbosity_error()

if __name__ == '__main__':
    dataset_name = "GroNLP/ik-nlp-22_slp"
    paragraphs = cast(DatasetDict, load_dataset(
        "GroNLP/ik-nlp-22_slp", "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    questions_test = questions["test"]

    # Initialize retriever
    retriever = FaissRetriever(paragraphs)
    #retriever = ESRetriever(paragraphs)

    # Retrieve example
    # random.seed(111)
    random_index = random.randint(0, len(questions_test["question"])-1)
    example_q = questions_test["question"][random_index]
    example_a = questions_test["answer"][random_index]

    scores, result = retriever.retrieve(example_q)
    reader_input = context_to_reader_input(result)

    # TODO: use new code from query.py to clean this up
    # Initialize reader
    reader = DprReader()
    answers = reader.read(example_q, reader_input)

    # Calculate softmaxed scores for readable output
    sm = torch.nn.Softmax(dim=0)
    document_scores = sm(torch.Tensor(
        [pred.relevance_score for pred in answers]))
    span_scores = sm(torch.Tensor(
        [pred.span_score for pred in answers]))

    print(example_q)
    for answer_i, answer in enumerate(answers):
        print(f"[{answer_i + 1}]: {answer.text}")
        print(f"\tDocument {answer.doc_id}", end='')
        print(f"\t(score {document_scores[answer_i] * 100:.02f})")
        print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
        print(f"\t(score {span_scores[answer_i] * 100:.02f})")
        print()  # Newline

    # print(f"Example q: {example_q} answer: {result['text'][0]}")

    # for i, score in enumerate(scores):
    #     print(f"Result {i+1} (score: {score:.02f}):")
    #     print(result['text'][i])

    # Determine best answer we want to evaluate
    highest, highest_index = 0, 0
    for i, value in enumerate(span_scores):
        if value + document_scores[i] > highest:
            highest = value + document_scores[i]
            highest_index = i

    # Retrieve exact match and F1-score
    exact_match, f1_score = evaluate(
        example_a, answers[highest_index].text)
    print(f"Gold answer: {example_a}\n"
          f"Predicted answer: {answers[highest_index].text}\n"
          f"Exact match: {exact_match:.02f}\n"
          f"F1-score: {f1_score:.02f}")

    # Calculate overall performance
    # total_f1 = 0
    # total_exact = 0
    # total_len = len(questions_test["question"])
    # start_time = time.time()
    # for i, question in enumerate(questions_test["question"]):
    #     print(question)
    #     answer = questions_test["answer"][i]
    #     print(answer)
    #
    #     scores, result = retriever.retrieve(question)
    #     reader_input = result_to_reader_input(result)
    #     answers = reader.read(question, reader_input)
    #
    #     document_scores = sm(torch.Tensor(
    #         [pred.relevance_score for pred in answers]))
    #     span_scores = sm(torch.Tensor(
    #         [pred.span_score for pred in answers]))
    #
    #     highest, highest_index = 0, 0
    #     for j, value in enumerate(span_scores):
    #         if value + document_scores[j] > highest:
    #             highest = value + document_scores[j]
    #             highest_index = j
    #     print(answers[highest_index])
    #     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
    #     total_f1 += f1_score
    #     total_exact += exact_match
    # print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
    # print(total_f1)
    # print(total_exact)
    # print(total_f1/total_len)