Ramon Meffert
committed on
Commit 07cae66
1 Parent(s): 492106d
Remove old code
main.py
CHANGED
@@ -125,92 +125,3 @@ if __name__ == '__main__':
     os.makedirs("./results/", exist_ok=True)
     f1_results.to_csv("./results/f1_scores.csv")
     em_results.to_csv("./results/em_scores.csv")
-
-# TODO evaluation and storing of results
-
-# # Initialize retriever
-# retriever = FaissRetriever(paragraphs)
-# # retriever = ESRetriever(paragraphs)
-
-# # Retrieve example
-# # random.seed(111)
-# random_index = random.randint(0, len(questions_test["question"])-1)
-# example_q = questions_test["question"][random_index]
-# example_a = questions_test["answer"][random_index]
-
-# scores, result = retriever.retrieve(example_q)
-# reader_input = context_to_reader_input(result)
-
-# # TODO: use new code from query.py to clean this up
-# # Initialize reader
-# answers = reader.read(example_q, reader_input)
-
-# # Calculate softmaxed scores for readable output
-# sm = torch.nn.Softmax(dim=0)
-# document_scores = sm(torch.Tensor(
-#     [pred.relevance_score for pred in answers]))
-# span_scores = sm(torch.Tensor(
-#     [pred.span_score for pred in answers]))
-
-# print(example_q)
-# for answer_i, answer in enumerate(answers):
-#     print(f"[{answer_i + 1}]: {answer.text}")
-#     print(f"\tDocument {answer.doc_id}", end='')
-#     print(f"\t(score {document_scores[answer_i] * 100:.02f})")
-#     print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
-#     print(f"\t(score {span_scores[answer_i] * 100:.02f})")
-#     print()  # Newline
-
-# # print(f"Example q: {example_q} answer: {result['text'][0]}")
-
-# # for i, score in enumerate(scores):
-# #     print(f"Result {i+1} (score: {score:.02f}):")
-# #     print(result['text'][i])
-
-# # Determine best answer we want to evaluate
-# highest, highest_index = 0, 0
-# for i, value in enumerate(span_scores):
-#     if value + document_scores[i] > highest:
-#         highest = value + document_scores[i]
-#         highest_index = i
-
-# # Retrieve exact match and F1-score
-# exact_match, f1_score = evaluate(
-#     example_a, answers[highest_index].text)
-# print(f"Gold answer: {example_a}\n"
-#       f"Predicted answer: {answers[highest_index].text}\n"
-#       f"Exact match: {exact_match:.02f}\n"
-#       f"F1-score: {f1_score:.02f}")
-
-# Calculate overall performance
-# total_f1 = 0
-# total_exact = 0
-# total_len = len(questions_test["question"])
-# start_time = time.time()
-# for i, question in enumerate(questions_test["question"]):
-#     print(question)
-#     answer = questions_test["answer"][i]
-#     print(answer)
-#
-#     scores, result = retriever.retrieve(question)
-#     reader_input = result_to_reader_input(result)
-#     answers = reader.read(question, reader_input)
-#
-#     document_scores = sm(torch.Tensor(
-#         [pred.relevance_score for pred in answers]))
-#     span_scores = sm(torch.Tensor(
-#         [pred.span_score for pred in answers]))
-#
-#     highest, highest_index = 0, 0
-#     for j, value in enumerate(span_scores):
-#         if value + document_scores[j] > highest:
-#             highest = value + document_scores[j]
-#             highest_index = j
-#     print(answers[highest_index])
-#     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
-#     total_f1 += f1_score
-#     total_exact += exact_match
-# print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
-# print(total_f1)
-# print(total_exact)
-# print(total_f1/total_len)
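The deleted block's core step was best-answer selection: softmax the reader's document-relevance scores and span scores separately so they are on a comparable scale, then keep the candidate whose two normalized scores sum highest. A minimal, self-contained sketch of just that step, assuming PyTorch; the Answer dataclass and the example scores below are hypothetical stand-ins for the reader's actual output, not the project's API:

from dataclasses import dataclass

import torch


@dataclass
class Answer:
    # Hypothetical stand-in for one reader prediction.
    text: str
    relevance_score: float
    span_score: float


def pick_best(answers: list[Answer]) -> Answer:
    sm = torch.nn.Softmax(dim=0)
    # Normalize each score type across candidates independently,
    # then rank candidates by the sum of the two softmaxed scores.
    document_scores = sm(torch.tensor([a.relevance_score for a in answers]))
    span_scores = sm(torch.tensor([a.span_score for a in answers]))
    best_index = int(torch.argmax(document_scores + span_scores))
    return answers[best_index]


if __name__ == "__main__":
    candidates = [
        Answer("Paris", relevance_score=3.2, span_score=1.1),
        Answer("Lyon", relevance_score=0.4, span_score=0.2),
    ]
    print(pick_best(candidates).text)  # -> Paris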