File size: 4,148 Bytes
af461f3 b7158e7 af461f3 51a31d4 f2e3e47 51dabd6 af461f3 ab5dfc2 af461f3 ab5dfc2 51a31d4 325e3c6 ab5dfc2 f2e3e47 ab5dfc2 51dabd6 51a31d4 e9df5ab 51a31d4 51dabd6 e9df5ab ab5dfc2 1fb8ae3 ade7e41 ab5dfc2 ade7e41 325e3c6 ab5dfc2 e9df5ab ab5dfc2 51dabd6 ab5dfc2 51a31d4 ab5dfc2 51dabd6 ab5dfc2 51dabd6 ade7e41 b7158e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import os
import random
from typing import cast
import time
import torch
import transformers
from datasets import DatasetDict, load_dataset
from dotenv import load_dotenv
from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
from src.utils.preprocessing import context_to_reader_input
# Module-level logger obtained from the project's logging helper.
logger = get_logger()
# Load variables from a .env file into the process environment (python-dotenv).
load_dotenv()
# Silence transformers log messages below ERROR level.
transformers.logging.set_verbosity_error()
def best_answer_index(span_scores, document_scores):
    """Return the index of the answer with the highest combined score.

    The combined score of answer ``i`` is
    ``span_scores[i] + document_scores[i]``. When several answers share the
    maximum, the first one wins. Unlike a running maximum seeded with ``0``,
    this also works when every combined score is zero or negative.

    Args:
        span_scores: Sequence (list or 1-D tensor) of per-answer span scores.
        document_scores: Sequence of per-answer document scores, same length
            as ``span_scores``.

    Returns:
        The integer position of the best answer.

    Raises:
        ValueError: If the two sequences differ in length or are empty.
    """
    if len(span_scores) != len(document_scores):
        raise ValueError(
            "span_scores and document_scores must have the same length")
    if len(span_scores) == 0:
        raise ValueError("cannot select the best answer from empty scores")

    combined = [span + doc for span, doc in zip(span_scores, document_scores)]
    # max() keeps the first maximal element, matching a strict '>' scan.
    return max(range(len(combined)), key=combined.__getitem__)


def main():
    """Answer one random test question and print the prediction quality.

    Loads the paragraphs and questions datasets, retrieves context for a
    randomly chosen test question, runs the reader on it, prints every
    candidate answer with its softmaxed document/span scores and finally
    prints exact match and F1 of the best candidate against the gold answer.
    """
    dataset_name = "GroNLP/ik-nlp-22_slp"
    # Use the same name for both configurations so they cannot drift apart.
    paragraphs = cast(DatasetDict, load_dataset(dataset_name, "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
    questions_test = questions["test"]

    # Initialize retriever
    retriever = FaissRetriever(paragraphs)
    # retriever = ESRetriever(paragraphs)

    # Retrieve example
    # random.seed(111)
    # Fetch each column once instead of re-reading it for every access.
    test_questions = questions_test["question"]
    test_answers = questions_test["answer"]
    random_index = random.randrange(len(test_questions))
    example_q = test_questions[random_index]
    example_a = test_answers[random_index]

    scores, result = retriever.retrieve(example_q)
    reader_input = context_to_reader_input(result)

    # TODO: use new code from query.py to clean this up
    # Initialize reader
    reader = DprReader()
    answers = reader.read(example_q, reader_input)

    print(example_q)

    # Without candidates there is nothing to score or evaluate; indexing
    # into the empty result would raise an IndexError further down.
    if len(answers) == 0:
        print("The reader returned no answers; nothing to evaluate.")
        return

    # Calculate softmaxed scores for readable output
    document_scores = torch.softmax(torch.Tensor(
        [pred.relevance_score for pred in answers]), dim=0)
    span_scores = torch.softmax(torch.Tensor(
        [pred.span_score for pred in answers]), dim=0)

    for answer_i, answer in enumerate(answers):
        print(f"[{answer_i + 1}]: {answer.text}")
        print(f"\tDocument {answer.doc_id}", end='')
        print(f"\t(score {document_scores[answer_i] * 100:.02f})")
        print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
        print(f"\t(score {span_scores[answer_i] * 100:.02f})")
        print()  # Newline

    # print(f"Example q: {example_q} answer: {result['text'][0]}")
    # for i, score in enumerate(scores):
    #     print(f"Result {i+1} (score: {score:.02f}):")
    #     print(result['text'][i])

    # Determine best answer we want to evaluate
    highest_index = best_answer_index(span_scores, document_scores)
    best_answer = answers[highest_index]

    # Retrieve exact match and F1-score
    exact_match, f1_score = evaluate(example_a, best_answer.text)
    print(f"Gold answer: {example_a}\n"
          f"Predicted answer: {best_answer.text}\n"
          f"Exact match: {exact_match:.02f}\n"
          f"F1-score: {f1_score:.02f}")


if __name__ == '__main__':
    main()
# Calculate overall performance
# total_f1 = 0
# total_exact = 0
# total_len = len(questions_test["question"])
# start_time = time.time()
# for i, question in enumerate(questions_test["question"]):
# print(question)
# answer = questions_test["answer"][i]
# print(answer)
#
# scores, result = retriever.retrieve(question)
# reader_input = context_to_reader_input(result)
# answers = reader.read(question, reader_input)
#
# document_scores = sm(torch.Tensor(
# [pred.relevance_score for pred in answers]))
# span_scores = sm(torch.Tensor(
# [pred.span_score for pred in answers]))
#
# highest, highest_index = 0, 0
# for j, value in enumerate(span_scores):
# if value + document_scores[j] > highest:
# highest = value + document_scores[j]
# highest_index = j
# print(answers[highest_index])
# exact_match, f1_score = evaluate(answer, answers[highest_index].text)
# total_f1 += f1_score
# total_exact += exact_match
# print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
# print(total_f1)
# print(total_exact)
# print(total_f1/total_len)
|