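"""Run retriever/reader experiments on the GroNLP/ik-nlp-22_slp dataset.

Each experiment pairs a retriever (FAISS or Elasticsearch) with a reader
(DPR or Longformer), answers the test questions, and records per-step
retrieve/read timings to ./results/timings.csv.
"""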
from dotenv import load_dotenv
# load_dotenv() must run before any other imports,
# otherwise HF ignores the env vars
load_dotenv()

import os
import pandas as pd
from dataclasses import dataclass
from typing import Dict, cast

from datasets import DatasetDict, load_dataset

from src.readers.base_reader import Reader
from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.readers.longformer_reader import LongformerReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
    FaissRetriever,
    FaissRetrieverOptions
)
from src.utils.log import logger
from src.utils.preprocessing import context_to_reader_input
from src.utils.timing import get_times, timeit


@dataclass
class Experiment:
    """A retriever + reader combination to benchmark."""
    retriever: Retriever
    reader: Reader


if __name__ == '__main__':
    dataset_name = "GroNLP/ik-nlp-22_slp"
    paragraphs = cast(DatasetDict, load_dataset(dataset_name, "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    # Evaluate on the full test set; lower subset_idx to only run on the
    # first few questions for speed
    subset_idx = len(questions["test"])
    questions_test = questions["test"][:subset_idx]
    experiments: Dict[str, Experiment] = {
        "faiss_dpr": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
            reader=DprReader()
        ),
        "faiss_longformer": Experiment(
            retriever=FaissRetriever(
                paragraphs,
                FaissRetrieverOptions.longformer(
                    "./src/models/longformer.faiss")),
            reader=LongformerReader()
        ),
        "es_dpr": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=DprReader()
        ),
        "es_longformer": Experiment(
            retriever=ESRetriever(paragraphs),
            reader=LongformerReader()
        ),
    }
    for experiment_name, experiment in experiments.items():
        logger.info(f"Running experiment {experiment_name}...")

        for idx in range(subset_idx):
            question = questions_test["question"][idx]
            answer = questions_test["answer"][idx]

            # Workaround so we can use the timing decorator with a dynamic
            # name for time recording
            retrieve_timer = timeit(f"{experiment_name}.retrieve")
            t_retrieve = retrieve_timer(experiment.retriever.retrieve)
            read_timer = timeit(f"{experiment_name}.read")
            t_read = read_timer(experiment.reader.read)

            print(f"\x1b[1K\r[{idx+1:03}] - \"{question}\"", end='')

            scores, context = t_retrieve(question, 5)
            reader_input = context_to_reader_input(context)
            answers = t_read(question, reader_input, 5)

            # Calculate softmaxed scores for readable output
            # sm = torch.nn.Softmax(dim=0)
            # document_scores = sm(torch.Tensor(
            #     [pred.relevance_score for pred in answers]))
            # span_scores = sm(torch.Tensor(
            #     [pred.span_score for pred in answers]))

            # print_answers(answers, scores, context)

            # TODO evaluation and storing of results
        print()

    # Save the recorded per-step timings
    times = get_times()
    df = pd.DataFrame(times)
    os.makedirs("./results/", exist_ok=True)
    df.to_csv("./results/timings.csv")
    # TODO evaluation and storing of results

    # # Initialize retriever
    # retriever = FaissRetriever(paragraphs)
    # # retriever = ESRetriever(paragraphs)

    # # Retrieve example
    # # random.seed(111)
    # random_index = random.randint(0, len(questions_test["question"])-1)
    # example_q = questions_test["question"][random_index]
    # example_a = questions_test["answer"][random_index]

    # scores, result = retriever.retrieve(example_q)
    # reader_input = context_to_reader_input(result)

    # # TODO: use new code from query.py to clean this up
    # # Initialize reader
    # answers = reader.read(example_q, reader_input)

    # # Calculate softmaxed scores for readable output
    # sm = torch.nn.Softmax(dim=0)
    # document_scores = sm(torch.Tensor(
    #     [pred.relevance_score for pred in answers]))
    # span_scores = sm(torch.Tensor(
    #     [pred.span_score for pred in answers]))

    # print(example_q)
    # for answer_i, answer in enumerate(answers):
    #     print(f"[{answer_i + 1}]: {answer.text}")
    #     print(f"\tDocument {answer.doc_id}", end='')
    #     print(f"\t(score {document_scores[answer_i] * 100:.02f})")
    #     print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
    #     print(f"\t(score {span_scores[answer_i] * 100:.02f})")
    #     print()  # Newline

    # # print(f"Example q: {example_q} answer: {result['text'][0]}")
    # # for i, score in enumerate(scores):
    # #     print(f"Result {i+1} (score: {score:.02f}):")
    # #     print(result['text'][i])

    # # Determine best answer we want to evaluate
    # highest, highest_index = 0, 0
    # for i, value in enumerate(span_scores):
    #     if value + document_scores[i] > highest:
    #         highest = value + document_scores[i]
    #         highest_index = i

    # # Retrieve exact match and F1-score
    # exact_match, f1_score = evaluate(
    #     example_a, answers[highest_index].text)

    # print(f"Gold answer: {example_a}\n"
    #       f"Predicted answer: {answers[highest_index].text}\n"
    #       f"Exact match: {exact_match:.02f}\n"
    #       f"F1-score: {f1_score:.02f}")

    # # Calculate overall performance
    # total_f1 = 0
    # total_exact = 0
    # total_len = len(questions_test["question"])
    # start_time = time.time()
    # for i, question in enumerate(questions_test["question"]):
    #     print(question)
    #     answer = questions_test["answer"][i]
    #     print(answer)
    #
    #     scores, result = retriever.retrieve(question)
    #     reader_input = result_to_reader_input(result)
    #     answers = reader.read(question, reader_input)
    #
    #     document_scores = sm(torch.Tensor(
    #         [pred.relevance_score for pred in answers]))
    #     span_scores = sm(torch.Tensor(
    #         [pred.span_score for pred in answers]))
    #
    #     highest, highest_index = 0, 0
    #     for j, value in enumerate(span_scores):
    #         if value + document_scores[j] > highest:
    #             highest = value + document_scores[j]
    #             highest_index = j
    #     print(answers[highest_index])
    #     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
    #     total_f1 += f1_score
    #     total_exact += exact_match

    # print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
    # print(total_f1)
    # print(total_exact)
    # print(total_f1/total_len)