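"""Command-line probe for the question answering pipeline.

Retrieves relevant paragraphs for a query and extracts candidate answers
with the selected retriever and reader.
"""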
import argparse
import torch
import transformers
from typing import Dict, List, Literal, Tuple, cast
from datasets import load_dataset, DatasetDict
from dotenv import load_dotenv
from src.readers.base_reader import Reader
from src.readers.longformer_reader import LongformerReader
from src.readers.dpr_reader import DprReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
FaissRetriever,
FaissRetrieverOptions
)
from src.utils.preprocessing import context_to_reader_input
from src.utils.log import logger

# Setup environment
load_dotenv()
transformers.logging.set_verbosity_error()


def get_retriever(paragraphs: DatasetDict,
r: Literal["es", "faiss"],
lm: Literal["dpr", "longformer"]) -> Retriever:
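    """Construct the requested retriever.

    Elasticsearch ignores the language model choice; the FAISS retriever
    loads a model-specific index from ./src/models/.
    """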
match (r, lm):
case "es", _:
return ESRetriever()
case "faiss", "dpr":
options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
return FaissRetriever(paragraphs, options)
case "faiss", "longformer":
options = FaissRetrieverOptions.longformer(
"./src/models/longformer.faiss")
return FaissRetriever(paragraphs, options)
case _:
raise ValueError("Retriever options not recognized")


def get_reader(lm: Literal["dpr", "longformer"]) -> Reader:
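    """Construct the reader that matches the chosen language model."""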
match lm:
case "dpr":
return DprReader()
case "longformer":
return LongformerReader()
case _:
raise ValueError("Language model not recognized")


def print_name(contexts: dict, section: str, id: int):
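    """Print the given part of a context's source (chapter, section or
    subsection), unless it is missing ('nan')."""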
name = contexts[section][id]
if name != 'nan':
print(f" {section}: {name}")


def get_retrieval_span_scores(answers: List[tuple]):
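    """Softmax-normalize the relevance and span scores across all answers."""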
# calculate answer scores
sm = torch.nn.Softmax(dim=0)
d_scores = sm(torch.Tensor(
[pred.relevance_score for pred in answers]))
s_scores = sm(torch.Tensor(
[pred.span_score for pred in answers]))
return d_scores, s_scores


def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
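    """Print each answer with its source location and normalized scores."""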
d_scores, s_scores = get_retrieval_span_scores(answers)
for pos, answer in enumerate(answers):
print(f"{pos + 1:>4}. {answer.text}")
print(f" {'-' * len(answer.text)}")
print_name(contexts, 'chapter', answer.doc_id)
print_name(contexts, 'section', answer.doc_id)
print_name(contexts, 'subsection', answer.doc_id)
print(f" retrieval score: {scores[answer.doc_id]:6.02f}%")
print(f" document score: {d_scores[pos] * 100:6.02f}%")
print(f" span score: {s_scores[pos] * 100:6.02f}%")
print()


def probe(query: str,
retriever: Retriever,
reader: Reader,
num_answers: int = 5) \
-> Tuple[List[tuple], List[float], Dict[str, List[str]]]:
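    """Retrieve contexts for the query and read answers from them.

    Returns the reader's top answers, the retrieval scores, and the
    retrieved contexts.
    """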
scores, contexts = retriever.retrieve(query)
reader_input = context_to_reader_input(contexts)
answers = reader.read(query, reader_input, num_answers)
return answers, scores, contexts


def default_probe(query: str):
    """Default probe: FAISS retrieval with the DPR reader, returning the top 5 answers."""
paragraphs = cast(DatasetDict, load_dataset(
"GroNLP/ik-nlp-22_slp", "paragraphs"))
retriever = get_retriever(paragraphs, "faiss", "dpr")
reader = DprReader()
return probe(query, retriever, reader)


def main(args: argparse.Namespace):
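    """Run the QA pipeline for the CLI arguments and print the results."""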
# Initialize dataset
paragraphs = cast(DatasetDict, load_dataset(
"GroNLP/ik-nlp-22_slp", "paragraphs"))

    # Retrieve
retriever = get_retriever(paragraphs, args.retriever, args.lm)
reader = get_reader(args.lm)
answers, scores, contexts = probe(
args.query, retriever, reader, args.top)

    # Print output
print("Question: " + args.query)
print("Answer(s):")
if args.lm == "dpr":
print_answers(answers, scores, contexts)
else:
answers = filter(lambda a: len(a[0].strip()) > 0, answers)
        for answer in answers:
            print(f"  - {answer[0].strip()}")


if __name__ == "__main__":
# Set up CLI arguments
parser = argparse.ArgumentParser(
formatter_class=argparse.MetavarTypeHelpFormatter
)
parser.add_argument(
"query", type=str, help="The question to feed to the QA system")
parser.add_argument(
"--top", "-t", type=int, default=1,
help="The number of answers to retrieve")
parser.add_argument(
"--retriever", "-r", type=str.lower, choices=["faiss", "es"],
default="faiss", help="The retrieval method to use")
parser.add_argument(
"--lm", "-l", type=str.lower,
choices=["dpr", "longformer"], default="dpr",
help="The language model to use for the FAISS retriever")
args = parser.parse_args()
main(args)
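
# Example invocation (script path assumed; adjust to where this file lives):
#   python probe.py "What is byte-pair encoding?" --retriever faiss --lm dpr --top 3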