Ramon Meffert commited on
Commit
1f08ed2
1 Parent(s): a1746cf

Add query cli w/ argparse

Browse files
Files changed (1) hide show
  1. query.py +84 -0
query.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import transformers
4
+
5
+ from typing import List
6
+ from datasets import load_dataset, DatasetDict
7
+ from dotenv import load_dotenv
8
+
9
+ from src.readers.dpr_reader import DprReader
10
+ from src.retrievers.base_retriever import Retriever
11
+ from src.retrievers.es_retriever import ESRetriever
12
+ from src.retrievers.faiss_retriever import FaissRetriever
13
+ from src.utils.preprocessing import result_to_reader_input
14
+ from src.utils.log import get_logger
15
+
16
+
17
def get_retriever(r: str, ds: DatasetDict) -> Retriever:
    """Instantiate the retriever selected by *r* over the dataset *ds*.

    "es" selects the Elasticsearch retriever; any other value falls back
    to the FAISS retriever (argparse restricts choices to these two).
    """
    if r == "es":
        retriever_cls = ESRetriever
    else:
        retriever_cls = FaissRetriever
    return retriever_cls(ds)
20
+
21
+
22
def print_name(contexts: dict, section: str, id: int):
    """Print the *section* title for document *id*, skipping missing values.

    Missing titles are encoded as the literal string 'nan' in the
    contexts mapping and produce no output.
    """
    title = contexts[section][id]
    if title == 'nan':
        return
    print(f" {section}: {title}")
26
+
27
+
28
def print_answers(answers: List, scores: List[float], contexts: dict):
    """Pretty-print the ranked reader answers with their scores.

    Args:
        answers: reader predictions; each element exposes ``.text``,
            ``.relevance_score``, ``.span_score`` and ``.doc_id``
            attributes. (Previously annotated as ``List[tuple]``, but
            tuples have no such attributes — the annotation was wrong
            and has been corrected to a plain ``List``.)
        scores: per-document retrieval scores, indexed by ``doc_id``.
        contexts: mapping of section name -> {doc_id: title}, as
            returned by the retriever.
    """
    # Normalize the raw reader scores into probabilities for display.
    sm = torch.nn.Softmax(dim=0)
    d_scores = sm(torch.Tensor(
        [pred.relevance_score for pred in answers]))
    s_scores = sm(torch.Tensor(
        [pred.span_score for pred in answers]))

    for pos, answer in enumerate(answers):
        print(f"{pos + 1:>4}. {answer.text}")
        print(f" {'-' * len(answer.text)}")
        # Titles encoded as 'nan' are skipped by print_name.
        print_name(contexts, 'chapter', answer.doc_id)
        print_name(contexts, 'section', answer.doc_id)
        print_name(contexts, 'subsection', answer.doc_id)
        print(f" retrieval score: {scores[answer.doc_id]:6.02f}%")
        print(f" document score: {d_scores[pos] * 100:6.02f}%")
        print(f" span score: {s_scores[pos] * 100:6.02f}%")
        print()
46
+
47
+
48
def main(args: argparse.Namespace):
    """Run the retrieve-then-read QA pipeline for one CLI query and print results."""
    # The dataset the retrievers index.
    dataset = load_dataset("GroNLP/ik-nlp-22_slp")

    # Retrieval stage: fetch candidate contexts for the query.
    retriever = get_retriever(args.retriever, dataset)
    retrieval_scores, retrieved_contexts = retriever.retrieve(args.query)

    # Reading stage: extract answer spans from the retrieved contexts.
    reader_input = result_to_reader_input(retrieved_contexts)
    answers = DprReader().read(args.query, reader_input, num_answers=args.top)

    # Report ranked answers with their scores.
    print_answers(answers, retrieval_scores, retrieved_contexts)
63
+
64
+
65
if __name__ == "__main__":
    # Load environment variables from .env, set up project logging, and
    # silence transformers' info/warning output.
    load_dotenv()
    logger = get_logger()
    transformers.logging.set_verbosity_error()

    # CLI definition.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.MetavarTypeHelpFormatter)
    parser.add_argument(
        "query", type=str,
        help="The question to feed to the QA system")
    parser.add_argument(
        "--top", "-t", type=int, default=1,
        help="The number of answers to retrieve")
    parser.add_argument(
        "--retriever", "-r", type=str.lower,
        choices=["faiss", "es"], default="faiss",
        help="The retrieval method to use")

    main(parser.parse_args())