File size: 2,816 Bytes
51a31d4 51dabd6 ab5dfc2 51a31d4 ade7e41 51a31d4 51dabd6 ab5dfc2 ade7e41 51dabd6 ab5dfc2 51dabd6 51a31d4 ab5dfc2 51a31d4 1fb8ae3 51dabd6 1fb8ae3 ab5dfc2 1fb8ae3 ade7e41 ab5dfc2 ade7e41 ab5dfc2 51dabd6 ab5dfc2 51a31d4 ab5dfc2 51dabd6 ab5dfc2 51dabd6 ade7e41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
"""Qualitative demo: retrieve paragraphs for a random test question and
extract an answer span with a DPR reader, then score it against the gold
answer."""
import os
import random

# Silence advisory warnings BEFORE transformers is imported, so the flag is
# guaranteed to be visible when its logging machinery initializes.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'

from typing import cast

import torch
import transformers
from datasets import DatasetDict, load_dataset

from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
from src.utils.preprocessing import result_to_reader_input

logger = get_logger()
# Keep transformers quiet except for real errors.
transformers.logging.set_verbosity_error()
if __name__ == '__main__':
    dataset_name = "GroNLP/ik-nlp-22_slp"

    # Load the question split and the paragraph collection once each.
    # (The original loaded the "paragraphs" config twice; the first result
    # was never used.)
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
    questions_test = questions["test"]
    dataset_paragraphs = cast(DatasetDict, load_dataset(
        dataset_name, "paragraphs"))

    # Initialize the retriever over the paragraph collection.
    retriever = FaissRetriever(dataset_paragraphs)

    # Pick a random test question; uncomment the seed for reproducibility.
    # random.seed(111)
    random_index = random.randrange(len(questions_test["question"]))
    example_q = questions_test["question"][random_index]
    example_a = questions_test["answer"][random_index]

    scores, result = retriever.retrieve(example_q)
    reader_input = result_to_reader_input(result)

    # Initialize the reader and extract candidate answer spans.
    reader = DprReader()
    answers = reader.read(example_q, reader_input)

    # Softmax the raw relevance/span logits so the printed numbers are
    # comparable probabilities rather than unbounded scores.
    sm = torch.nn.Softmax(dim=0)
    document_scores = sm(torch.Tensor(
        [pred.relevance_score for pred in answers]))
    span_scores = sm(torch.Tensor(
        [pred.span_score for pred in answers]))

    print(example_q)
    for answer_i, answer in enumerate(answers):
        print(f"[{answer_i + 1}]: {answer.text}")
        print(f"\tDocument {answer.doc_id}", end='')
        print(f"\t(score {document_scores[answer_i] * 100:.02f})")
        print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
        print(f"\t(score {span_scores[answer_i] * 100:.02f})")
        print()  # Newline

    # Best answer = candidate with the highest combined (document + span)
    # probability. Softmax outputs are strictly positive, so a plain argmax
    # is equivalent to the original sentinel-initialized loop.
    highest_index = max(
        range(len(answers)),
        key=lambda i: span_scores[i] + document_scores[i],
    )

    # Score the chosen span against the gold answer.
    exact_match, f1_score = evaluate(
        example_a, answers[highest_index].text)
    print(f"Gold answer: {example_a}\n"
          f"Predicted answer: {answers[highest_index].text}\n"
          f"Exact match: {exact_match:.02f}\n"
          f"F1-score: {f1_score:.02f}")
|