# Demo entry point: pick one question from the test split, retrieve passages
# for it with the FAISS retriever, run the DPR reader over them and print the
# predicted answers together with softmax-normalised scores.
#
# The import order below is kept exactly as in the original script.
from datasets import DatasetDict, load_dataset
from src.readers.dpr_reader import DprReader
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
# from src.evaluation import evaluate
from typing import cast
from src.utils.preprocessing import result_to_reader_input
import torch
import transformers
import os

# Module-level setup: set the transformers advisory-warning environment flag
# and lower the library's log verbosity to errors only.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'

logger = get_logger()
transformers.logging.set_verbosity_error()


def _softmax(values):
    """Return the softmax over dimension 0 of ``values`` as a torch tensor.

    ``values`` is a flat list of raw scores; it is converted with
    ``torch.Tensor`` before normalisation.
    """
    return torch.nn.functional.softmax(torch.Tensor(values), dim=0)


def main():
    """Run retrieval and reading for a single example question and print it."""
    dataset_name = "GroNLP/ik-nlp-22_slp"

    # NOTE: the "paragraphs" configuration is loaded but not referenced again
    # in this script; the call is kept so its side effects are unchanged.
    paragraphs = load_dataset(dataset_name, "paragraphs")
    question_splits = cast(DatasetDict, load_dataset(dataset_name, "questions"))
    test_split = question_splits["test"]
    # logger.info(questions)

    # Initialize retriever
    retriever = FaissRetriever()

    # Retrieve example: shuffle() is called without a seed, then the first
    # question of the shuffled split is used.
    example_q = test_split.shuffle()["question"][0]
    _scores, retrieved = retriever.retrieve(example_q)
    passages = result_to_reader_input(retrieved)

    # Initialize reader
    reader = DprReader()
    answers = reader.read(example_q, passages)

    # Calculate softmaxed scores for readable output
    document_scores = _softmax([pred.relevance_score for pred in answers])
    span_scores = _softmax([pred.span_score for pred in answers])

    print(example_q)
    for answer_i, answer in enumerate(answers):
        print(f"[{answer_i + 1}]: {answer.text}")
        print(f"\tDocument {answer.doc_id}", end='')
        print(f"\t(score {document_scores[answer_i] * 100:.02f})")
        print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
        print(f"\t(score {span_scores[answer_i] * 100:.02f})")
        print()  # Newline

    # print(f"Example q: {example_q} answer: {result['text'][0]}")

    # for i, score in enumerate(scores):
    #     print(f"Result {i+1} (score: {score:.02f}):")
    #     print(result['text'][i])

    # # Compute overall performance
    # exact_match, f1_score = evaluate(
    #     r, questions_test["question"], questions_test["answer"])
    # print(f"Exact match: {exact_match:.02f}\n", f"F1-score: {f1_score:.02f}")


if __name__ == '__main__':
    main()