Robert committed on
Commit ade7e41
1 parent: 112c207

Evaluation uses best predicted answer now

Files changed (2):
  1. main.py +21 -7
  2. src/evaluation.py +4 -19
main.py CHANGED
@@ -3,7 +3,7 @@ from datasets import DatasetDict, load_dataset
 from src.readers.dpr_reader import DprReader
 from src.retrievers.faiss_retriever import FaissRetriever
 from src.utils.log import get_logger
-# from src.evaluation import evaluate
+from src.evaluation import evaluate
 from typing import cast
 
 from src.utils.preprocessing import result_to_reader_input
@@ -11,6 +11,7 @@ from src.utils.preprocessing import result_to_reader_input
 import torch
 import transformers
 import os
+import random
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
 
@@ -30,9 +31,12 @@ if __name__ == '__main__':
     retriever = FaissRetriever()
 
     # Retrieve example
-    example_q = questions_test.shuffle()["question"][0]
-    scores, result = retriever.retrieve(example_q)
-
+    #random.seed(111)
+    random_index = random.randint(0, len(questions_test["question"])-1)
+    example_q = questions_test["question"][random_index]
+    example_a = questions_test["answer"][random_index]
+
+    scores, result = retriever.retrieve(example_q)
     reader_input = result_to_reader_input(result)
 
     # Initialize reader
@@ -61,7 +65,17 @@ if __name__ == '__main__':
         # print(f"Result {i+1} (score: {score:.02f}):")
         # print(result['text'][i])
 
-    # # Compute overall performance
-    # exact_match, f1_score = evaluate(
-    #     r, questions_test["question"], questions_test["answer"])
-    # print(f"Exact match: {exact_match:.02f}\n", f"F1-score: {f1_score:.02f}")
+    # Determine best answer we want to evaluate
+    highest, highest_index = 0, 0
+    for i, value in enumerate(span_scores):
+        if value + document_scores[i] > highest:
+            highest = value + document_scores[i]
+            highest_index = i
+
+    # Retrieve exact match and F1-score
+    exact_match, f1_score = evaluate(
+        example_a, answers[highest_index].text)
+    print(f"Gold answer: {example_a}\n"
+          f"Predicted answer: {answers[highest_index].text}\n"
+          f"Exact match: {exact_match:.02f}\n"
+          f"F1-score: {f1_score:.02f}")
src/evaluation.py CHANGED
@@ -66,27 +66,12 @@ def f1(prediction: str, answer: str) -> float:
     return 2 * (prec * rec) / (prec + rec)
 
 
-def evaluate(retriever: Retriever, questions: Any, answers: Any):
-    """Evaluates the entire model by computing F1-score and exact match on the
-    entire dataset.
+def evaluate(answer: Any, prediction: Any):
+    """Evaluates the model by computing F1-score and exact match of the best
+    predicted answer on a random sentence.
 
     Returns:
         float: overall exact match
         float: overall F1-score
     """
-
-    predictions = []
-    scores = 0
-
-    # Currently just takes the first answer and does not look at scores yet
-    for question in questions:
-        score, result = retriever.retrieve(question, 1)
-        scores += score[0]
-        predictions.append(result['text'][0])
-
-    exact_matches = [exact_match(
-        predictions[i], answers[i]) for i in range(len(answers))]
-    f1_scores = [f1(
-        predictions[i], answers[i]) for i in range(len(answers))]
-
-    return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
+    return exact_match(prediction, answer), f1(prediction, answer)
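
Usage sketch (not part of the commit) for the new evaluate(answer, prediction) signature. The exact_match and f1 helpers below are simplified, whitespace-tokenised stand-ins for illustration; the real implementations earlier in src/evaluation.py may normalise punctuation and articles differently.

from typing import Tuple


def exact_match(prediction: str, answer: str) -> float:
    # Stand-in: case-insensitive string equality.
    return float(prediction.strip().lower() == answer.strip().lower())


def f1(prediction: str, answer: str) -> float:
    # Stand-in: token-level F1 over whitespace-split, lower-cased tokens.
    pred_tokens = prediction.lower().split()
    gold_tokens = answer.lower().split()
    common = sum(min(pred_tokens.count(t), gold_tokens.count(t))
                 for t in set(pred_tokens))
    if common == 0:
        return 0.0
    prec = common / len(pred_tokens)
    rec = common / len(gold_tokens)
    return 2 * (prec * rec) / (prec + rec)


def evaluate(answer: str, prediction: str) -> Tuple[float, float]:
    # Mirrors the new single-example signature: metrics for one (gold, predicted) pair.
    return exact_match(prediction, answer), f1(prediction, answer)


em, f1_score = evaluate("April 1912", "in April 1912")
print(f"Exact match: {em:.02f}\nF1-score: {f1_score:.02f}")  # 0.00 / 0.80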