GGroenendaal committed on
Commit 51a31d4 • 1 Parent(s): 51dabd6

refactor evaluation

main.py CHANGED
@@ -1,23 +1,38 @@
-from src.fais_retriever import FAISRetriever
-from src.utils.log import get_logger
-
+from datasets import DatasetDict, load_dataset
+
+from src.retrievers.fais_retriever import FAISRetriever
+from src.utils.log import get_logger
+from src.evaluation import evaluate
+from typing import cast
 
 logger = get_logger()
 
 
 if __name__ == '__main__':
+    dataset_name = "GroNLP/ik-nlp-22_slp"
+    paragraphs = load_dataset(dataset_name, "paragraphs")
+    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
+
+    questions_test = questions["test"]
+
+    logger.info(questions)
+
     # Initialize retriever
     r = FAISRetriever()
 
-    # Retrieve example
-    scores, result = r.retrieve(
-        "What is the perplexity of a language model?")
+    # Retrieve example
+    example_q = "What is the perplexity of a language model?"
+    scores, result = r.retrieve(example_q)
+
+    logger.info(
+        f"Example q: {example_q} answer: {result['text'][0]}")
 
     for i, score in enumerate(scores):
         logger.info(f"Result {i+1} (score: {score:.02f}):")
         logger.info(result['text'][i])
 
     # Compute overall performance
-    exact_match, f1_score = r.evaluate()
+    exact_match, f1_score = evaluate(
+        r, questions_test["question"], questions_test["answer"])
     logger.info(f"Exact match: {exact_match:.02f}\n"
                 f"F1-score: {f1_score:.02f}")
src/evaluation.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Callable, List
+from typing import Any, Callable, List
+from src.retrievers.base_retriever import Retriever
 
 from src.utils.string_utils import (lower, remove_articles, remove_punc,
                                     white_space_fix)
@@ -63,3 +64,29 @@ def f1(prediction: str, answer: str) -> float:
     rec = len(common_tokens) / len(answer_tokens)
 
     return 2 * (prec * rec) / (prec + rec)
+
+
+def evaluate(retriever: Retriever, questions: Any, answers: Any):
+    """Evaluates the entire model by computing F1-score and exact match on the
+    entire dataset.
+
+    Returns:
+        float: overall exact match
+        float: overall F1-score
+    """
+
+    predictions = []
+    scores = 0
+
+    # Currently just takes the first answer and does not look at scores yet
+    for question in questions:
+        score, result = retriever.retrieve(question, 1)
+        scores += score[0]
+        predictions.append(result['text'][0])
+
+    exact_matches = [exact_match(
+        predictions[i], answers[i]) for i in range(len(answers))]
+    f1_scores = [f1(
+        predictions[i], answers[i]) for i in range(len(answers))]
+
+    return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
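Because `evaluate` now depends only on the `Retriever` interface rather than on FAISS internals, it can be exercised without building any index. A hypothetical stub (not part of this commit) makes the contract explicit: `retrieve(question, 1)` must return a `(scores, results)` pair with the top answer at `results['text'][0]`:

```python
from src.evaluation import evaluate
from src.retrievers.base_retriever import Retriever


class EchoRetriever(Retriever):
    """Hypothetical stub that always 'retrieves' the gold answer."""

    def __init__(self, answer_for: dict):
        self.answer_for = answer_for

    def retrieve(self, query: str, k: int):
        # Mirror the (scores, results) shape that evaluate() unpacks.
        return [1.0], {"text": [self.answer_for[query]]}


questions = ["What is the perplexity of a language model?"]
answers = ["A measure of how well a model predicts a sample."]
em, f1_score = evaluate(
    EchoRetriever(dict(zip(questions, answers))), questions, answers)
assert em == 1.0 and f1_score == 1.0  # perfect predictions score 1.0 on both
```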
src/retrievers/base_retriever.py ADDED
@@ -0,0 +1,3 @@
+class Retriever():
+    def retrieve(self, query: str, k: int):
+        pass
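As committed, `Retriever` is a plain class whose `retrieve` silently returns `None`, so a subclass that forgets to override it only fails once the result is unpacked. An alternative sketch (not what the commit does) that surfaces the error at instantiation time via `abc`:

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple


class Retriever(ABC):
    """Abstract variant: subclasses must override retrieve to be instantiable."""

    @abstractmethod
    def retrieve(self, query: str, k: int) -> Tuple[Any, Any]:
        """Return (scores, results) for the k best matches."""
```

The trade-off is that the abstract version cannot be instantiated directly, which the still-stubbed `ESRetriever` below may be deliberately avoiding.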
src/{es_retriever.py → retrievers/es_retriever.py} RENAMED
@@ -1,8 +1,11 @@
-class ESRetriever:
-    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp"):
-        self.dataset_name = dataset_name
+from src.retrievers.base_retriever import Retriever
+from src.utils.log import get_logger
 
-    def _setup_data(self):
+logger = get_logger()
+
+
+class ESRetriever(Retriever):
+    def __init__(self, data_set):
         pass
 
     def retrieve(self, query: str, k: int):
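`ESRetriever` remains a stub after the rename. Following the same https://huggingface.co/docs/datasets/faiss_es guide that `FAISRetriever` cites, one plausible shape is sketched below; it is hypothetical, not from the commit, and assumes an Elasticsearch instance reachable at localhost:9200 plus a `text` column on the paragraphs split:

```python
from datasets import load_dataset

from src.retrievers.base_retriever import Retriever


class ESRetrieverSketch(Retriever):
    """Hypothetical Elasticsearch counterpart to FAISRetriever."""

    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp"):
        self.ds = load_dataset(dataset_name, name="paragraphs")["train"]
        # Build a full-text (BM25) index over the paragraph text column;
        # assumes Elasticsearch is running locally on the default port.
        self.ds.add_elasticsearch_index("text", host="localhost", port="9200")

    def retrieve(self, query: str, k: int = 5):
        # Same (scores, results) contract that evaluate() expects.
        return self.ds.get_nearest_examples("text", query, k=k)
```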
src/{fais_retriever.py → retrievers/fais_retriever.py} RENAMED
@@ -1,19 +1,27 @@
-# Hacky fix for FAISS error on macOS
-# See https://stackoverflow.com/a/63374568/4545692
 import os
 import os.path
 
 import torch
 from datasets import load_dataset
-from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
-                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
+from transformers import (
+    DPRContextEncoder,
+    DPRContextEncoderTokenizer,
+    DPRQuestionEncoder,
+    DPRQuestionEncoderTokenizer,
+)
 
-from src.evaluation import exact_match, f1
+from src.retrievers.base_retriever import Retriever
+from src.utils.log import get_logger
 
 os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
+# Hacky fix for FAISS error on macOS
+# See https://stackoverflow.com/a/63374568/4545692
+
 
+logger = get_logger()
 
-class FAISRetriever:
+
+class FAISRetriever(Retriever):
     """A class used to retrieve relevant documents based on some query.
     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
     """
@@ -65,7 +73,7 @@ class FAISRetriever:
         # Load dataset
         ds = load_dataset(dataset_name, name="paragraphs")[
             "train"]  # type: ignore
-        print(ds)
+        logger.info(ds)
 
         if os.path.exists(embedding_path):
             # If we already have FAISS embeddings, load them from disk
@@ -115,32 +123,3 @@ class FAISRetriever:
         )
 
         return scores, results
-
-    def evaluate(self):
-        """Evaluates the entire model by computing F1-score and exact match on the
-        entire dataset.
-
-        Returns:
-            float: overall exact match
-            float: overall F1-score
-        """
-        questions_ds = load_dataset(
-            self.dataset_name, name="questions")['test']
-        questions = questions_ds['question']
-        answers = questions_ds['answer']
-
-        predictions = []
-        scores = 0
-
-        # Currently just takes the first answer and does not look at scores yet
-        for question in questions:
-            score, result = self.retrieve(question, 1)
-            scores += score[0]
-            predictions.append(result['text'][0])
-
-        exact_matches = [exact_match(
-            predictions[i], answers[i]) for i in range(len(answers))]
-        f1_scores = [f1(
-            predictions[i], answers[i]) for i in range(len(answers))]
-
-        return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
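For context on what `FAISRetriever` wraps: the cited docs page builds retrieval in three steps, embedding every paragraph with the DPR context encoder, adding a FAISS index over those embeddings, then embedding the query and asking for nearest examples. A condensed sketch of that pipeline follows; the checkpoint names are the standard DPR ones from the docs, not confirmed by this diff, and the real class additionally caches embeddings at `embedding_path`:

```python
import torch
from datasets import load_dataset
from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)

# Standard DPR checkpoints (assumed; the diff does not show which are used).
ctx_encoder = DPRContextEncoder.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base")

ds = load_dataset("GroNLP/ik-nlp-22_slp", name="paragraphs")["train"]


def embed(batch):
    # One dense vector per paragraph from the DPR context encoder.
    inputs = ctx_tokenizer(batch["text"], truncation=True,
                           padding=True, return_tensors="pt")
    with torch.no_grad():
        return {"embeddings": ctx_encoder(**inputs).pooler_output.numpy()}


ds = ds.map(embed, batched=True, batch_size=16)
ds.add_faiss_index(column="embeddings")

# Query side: embed the question, then fetch the k nearest paragraphs.
query = "What is the perplexity of a language model?"
with torch.no_grad():
    q_emb = q_encoder(
        **q_tokenizer(query, return_tensors="pt")).pooler_output.numpy()[0]
scores, results = ds.get_nearest_examples("embeddings", q_emb, k=5)
```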