Robert committed
Commit ade7e41 • Parent(s): 112c207

Evaluation uses best predicted answer now

Files changed:
- main.py (+21, -7)
- src/evaluation.py (+4, -19)
main.py CHANGED

@@ -3,7 +3,7 @@ from datasets import DatasetDict, load_dataset
 from src.readers.dpr_reader import DprReader
 from src.retrievers.faiss_retriever import FaissRetriever
 from src.utils.log import get_logger
-
+from src.evaluation import evaluate
 from typing import cast

 from src.utils.preprocessing import result_to_reader_input

@@ -11,6 +11,7 @@ from src.utils.preprocessing import result_to_reader_input
 import torch
 import transformers
 import os
+import random

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'

@@ -30,9 +31,12 @@ if __name__ == '__main__':
     retriever = FaissRetriever()

     # Retrieve example
-
-
+    #random.seed(111)
+    random_index = random.randint(0, len(questions_test["question"])-1)
+    example_q = questions_test["question"][random_index]
+    example_a = questions_test["answer"][random_index]

+    scores, result = retriever.retrieve(example_q)
     reader_input = result_to_reader_input(result)

     # Initialize reader

@@ -61,7 +65,17 @@ if __name__ == '__main__':
     # print(f"Result {i+1} (score: {score:.02f}):")
     # print(result['text'][i])

-    #
-
-
-
+    # Determine best answer we want to evaluate
+    highest, highest_index = 0, 0
+    for i, value in enumerate(span_scores):
+        if value + document_scores[i] > highest:
+            highest = value + document_scores[i]
+            highest_index = i
+
+    # Retrieve exact match and F1-score
+    exact_match, f1_score = evaluate(
+        example_a, answers[highest_index].text)
+    print(f"Gold answer: {example_a}\n"
+          f"Predicted answer: {answers[highest_index].text}\n"
+          f"Exact match: {exact_match:.02f}\n"
+          f"F1-score: {f1_score:.02f}")
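For context, the selection loop added above is just an argmax over combined span and document scores. A minimal standalone sketch of the same rule, with made-up scores and a hypothetical Answer stand-in for the reader's span objects (the real ones come from DprReader and expose a .text attribute, as the diff shows):

from collections import namedtuple

# Hypothetical stand-in for the reader's answer spans; the real objects
# come from DprReader and carry a .text attribute per span.
Answer = namedtuple("Answer", ["text"])

answers = [Answer("Paris"), Answer("London"), Answer("Berlin")]
span_scores = [0.2, 0.9, 0.4]      # per-span reader scores (made up)
document_scores = [0.5, 0.3, 0.1]  # per-document retriever scores (made up)

# Same rule as the loop in the diff: pick the index whose combined
# span + document score is highest.
best = max(range(len(answers)),
           key=lambda i: span_scores[i] + document_scores[i])
print(answers[best].text)  # London (0.9 + 0.3 = 1.2 is the highest sum)

One behavioural difference worth noting: max always picks a candidate, whereas the loop's highest = 0 start would keep index 0 if every combined score were zero or negative.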
src/evaluation.py CHANGED

@@ -66,27 +66,12 @@ def f1(prediction: str, answer: str) -> float:
     return 2 * (prec * rec) / (prec + rec)


-def evaluate(
-    """Evaluates the
-
+def evaluate(answer: Any, prediction: Any):
+    """Evaluates the model by computing F1-score and exact match of the best
+    predicted answer on a random sentence.

     Returns:
         float: overall exact match
         float: overall F1-score
     """
-
-    predictions = []
-    scores = 0
-
-    # Currently just takes the first answer and does not look at scores yet
-    for question in questions:
-        score, result = retriever.retrieve(question, 1)
-        scores += score[0]
-        predictions.append(result['text'][0])
-
-    exact_matches = [exact_match(
-        predictions[i], answers[i]) for i in range(len(answers))]
-    f1_scores = [f1(
-        predictions[i], answers[i]) for i in range(len(answers))]
-
-    return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
+    return exact_match(prediction, answer), f1(prediction, answer)
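The new evaluate is now a thin wrapper over the module's existing exact_match and f1 helpers (the Any annotation presumes a from typing import Any near the top of src/evaluation.py, which this hunk does not show). A hedged usage sketch, assuming the helpers implement the SQuAD-style token-overlap F1 suggested by the prec/rec formula above:

from src.evaluation import evaluate

# Gold answer first, best predicted span second, matching main.py's call.
em, f1_score = evaluate("the Eiffel Tower", "Eiffel Tower")

# Assuming SQuAD-style scoring with no extra normalization: exact match
# is 0.0 since the strings differ, while token-level F1 credits the two
# overlapping tokens: precision 2/2, recall 2/3 -> F1 = 0.8.
print(f"Exact match: {em:.02f}, F1-score: {f1_score:.02f}")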