JohnnyBoy00
committed on
Commit
·
1b7c795
1
Parent(s):
60d919b
Upload evaluation.py
Browse files- evaluation.py +173 -0
evaluation.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from evaluate import load as load_metric
|
5 |
+
|
6 |
+
from sklearn.metrics import accuracy_score, f1_score
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
|
9 |
+
# value passed as `max_length` to `model.generate` in `evaluate`
MAX_TARGET_LENGTH = 128

# load every evaluation metric once, when the module is imported
sacrebleu, rouge, meteor, bertscore = (
    load_metric(metric_name)
    for metric_name in ('sacrebleu', 'rouge', 'meteor', 'bertscore')
)

# run on the gpu when one is available, otherwise fall back to the cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
19 |
+
|
20 |
+
def flatten_list(l):
    """
    Utility function that concatenates the sublists of a list into one list

    Only one level of nesting is removed; the order of the elements is kept.

    Params:
        l (list of lists): list to be flattened
    Returns:
        A new list holding the elements of every sublist, in order
    """
    flattened = []
    for sublist in l:
        flattened.extend(sublist)
    return flattened
|
30 |
+
|
31 |
+
def extract_feedback(predictions):
    """
    Utility function to extract the feedback from the predictions of the model

    The feedback is taken to be everything after the first ':' of a
    prediction. When a prediction has no ':', the leading label words are
    dropped instead: two words for 'partially correct', one word otherwise.
    When nothing is left after the label, the whole prediction is kept.

    Params:
        predictions (list): complete model predictions
    Returns:
        feedback (list): extracted feedback from the model's predictions
    """
    feedback = []
    # iterate through predictions and try to extract predicted feedback
    for pred in predictions:
        _, colon, fb = pred.partition(':')
        if not colon:
            # no separator: skip the label words at the start of the prediction.
            # 'partially correct' is a two-word label, so two splits are needed
            # to reach the text that follows it.
            label_words = 2 if pred.lower().startswith('partially correct') else 1
            parts = pred.split(' ', label_words)
            # fall back to the complete prediction when only the label is present
            fb = parts[label_words] if len(parts) > label_words else pred
        feedback.append(fb.strip())

    return feedback
|
56 |
+
|
57 |
+
def extract_labels(predictions):
    """
    Utility function to extract the labels from the predictions of the model

    The label is recognised from the (case-insensitive) start of each
    prediction; predictions that start with none of the known labels are
    mapped to 'Unknown label'.

    Params:
        predictions (list): complete model predictions
    Returns:
        labels (list): extracted labels from the model's predictions
    """
    # (lower-case prefix, canonical label) pairs, checked in this order
    known_labels = (
        ('correct', 'Correct'),
        ('partially correct', 'Partially correct'),
        ('incorrect', 'Incorrect'),
    )

    labels = []
    for pred in predictions:
        lowered = pred.lower()
        matches = (name for prefix, name in known_labels if lowered.startswith(prefix))
        labels.append(next(matches, 'Unknown label'))

    return labels
|
79 |
+
|
80 |
+
def compute_metrics(predictions, labels):
    """
    Compute evaluation metrics from the predictions of the model

    The feedback text is scored with sacrebleu, rouge-2, meteor and
    bertscore (mean f1); the labels are scored with accuracy and f1.

    Params:
        predictions (list): complete model predictions, as strings
        labels (list): golden labels, as strings that contain 'Feedback:'
            between the label and the feedback
    Returns:
        results (dict): dictionary with the computed evaluation metrics
    """
    # model side: feedback text and label of every prediction
    predicted_feedback = extract_feedback(predictions)
    predicted_labels = extract_labels(predictions)

    # golden side: the text after 'Feedback:' is the feedback, the text before it the label
    reference_feedback = [gold.split('Feedback:', 1)[1].strip() for gold in labels]
    reference_labels_np = np.array([gold.split('Feedback:', 1)[0].strip() for gold in labels])

    # arguments shared by the metrics that take one flat reference per prediction
    text_pair = {'predictions': predicted_feedback, 'references': reference_feedback}

    # sacrebleu expects a list of references for every prediction
    sacrebleu_result = sacrebleu.compute(
        predictions=predicted_feedback,
        references=[[reference] for reference in reference_feedback])
    rouge_result = rouge.compute(**text_pair)
    meteor_result = meteor.compute(**text_pair)
    bertscore_result = bertscore.compute(lang='en', rescale_with_baseline=True, **text_pair)

    # classification metrics from sklearn
    return {
        'sacrebleu': sacrebleu_result['score'],
        'rouge': rouge_result['rouge2'],
        'meteor': meteor_result['meteor'],
        'bert_score': np.mean(np.array(bertscore_result['f1'])).item(),
        'accuracy': accuracy_score(reference_labels_np, predicted_labels),
        'f1_weighted': f1_score(reference_labels_np, predicted_labels, average='weighted'),
        'f1_macro': f1_score(
            reference_labels_np,
            predicted_labels,
            average='macro',
            labels=['Incorrect', 'Partially correct', 'Correct']),
    }
|
130 |
+
|
131 |
+
def evaluate(model, tokenizer, dataloader):
    """
    Evaluate model on the given dataset

    Params:
        model (PreTrainedModel): seq2seq model
        tokenizer (PreTrainedTokenizer): tokenizer from HuggingFace
        dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation;
            every batch must provide 'input_ids', 'attention_mask' and 'labels'
    Returns:
        results (dict): dictionary with the computed evaluation metrics
        predictions (list): list of the decoded predictions of the model
    """
    predictions, labels = [], []

    model.eval()
    # iterate through batches in the dataloader
    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            # generate tokens from batch
            generated_tokens = model.generate(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=MAX_TARGET_LENGTH
            )
            # get golden labels from batch
            labels_batch = batch['labels']

        # label positions padded with -100 (the ignore index used by e.g.
        # DataCollatorForSeq2Seq) are not valid token ids and cannot be decoded,
        # so they are mapped to the pad token, which is dropped as a special token
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is not None:
            labels_batch = labels_batch.masked_fill(labels_batch == -100, pad_token_id)

        # decode model predictions and golden labels into flat lists of strings
        predictions.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
        labels.extend(tokenizer.batch_decode(labels_batch, skip_special_tokens=True))

    # compute metrics based on predictions and golden labels
    results = compute_metrics(predictions, labels)

    return results, predictions
|