JohnnyBoy00 commited on
Commit
1b7c795
·
1 Parent(s): 60d919b

Upload evaluation.py

Browse files
Files changed (1) hide show
  1. evaluation.py +173 -0
evaluation.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from evaluate import load as load_metric
5
+
6
+ from sklearn.metrics import accuracy_score, f1_score
7
+ from tqdm.auto import tqdm
8
+
9
+ MAX_TARGET_LENGTH = 128
10
+
11
+ # load evaluation metrics
12
+ sacrebleu = load_metric('sacrebleu')
13
+ rouge = load_metric('rouge')
14
+ meteor = load_metric('meteor')
15
+ bertscore = load_metric('bertscore')
16
+
17
+ # use gpu if it's available
18
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
19
+
20
+ def flatten_list(l):
21
+ """
22
+ Utility function to convert a list of lists into a flattened list
23
+
24
+ Params:
25
+ l (list of lists): list to be flattened
26
+ Returns:
27
+ A flattened list with the elements of the original list
28
+ """
29
+ return [item for sublist in l for item in sublist]
30
+
31
+ def extract_feedback(predictions):
32
+ """
33
+ Utility function to extract the feedback from the predictions of the model
34
+
35
+ Params:
36
+ predictions (list): complete model predictions
37
+ Returns:
38
+ feedback (list): extracted feedback from the model's predictions
39
+ """
40
+ feedback = []
41
+ # iterate through predictions and try to extract predicted feedback
42
+ for pred in predictions:
43
+ try:
44
+ fb = pred.split(':', 1)[1]
45
+ except IndexError:
46
+ try:
47
+ if pred.lower().startswith('partially correct'):
48
+ fb = pred.split(' ', 1)[2]
49
+ else:
50
+ fb = pred.split(' ', 1)[1]
51
+ except IndexError:
52
+ fb = pred
53
+ feedback.append(fb.strip())
54
+
55
+ return feedback
56
+
57
+ def extract_labels(predictions):
58
+ """
59
+ Utility function to extract the labels from the predictions of the model
60
+
61
+ Params:
62
+ predictions (list): complete model predictions
63
+ Returns:
64
+ feedback (list): extracted labels from the model's predictions
65
+ """
66
+ labels = []
67
+ for pred in predictions:
68
+ if pred.lower().startswith('correct'):
69
+ label = 'Correct'
70
+ elif pred.lower().startswith('partially correct'):
71
+ label = 'Partially correct'
72
+ elif pred.lower().startswith('incorrect'):
73
+ label = 'Incorrect'
74
+ else:
75
+ label = 'Unknown label'
76
+ labels.append(label)
77
+
78
+ return labels
79
+
80
+ def compute_metrics(predictions, labels):
81
+ """
82
+ Compute evaluation metrics from the predictions of the model
83
+
84
+ Params:
85
+ predictions (list): complete model predictions
86
+ labels (list): golden labels (previously tokenized)
87
+ Returns:
88
+ results (dict): dictionary with the computed evaluation metrics
89
+ predictions (list): list of the decoded predictions of the model
90
+ """
91
+ # extract feedback and labels from the model's predictions
92
+ predicted_feedback = extract_feedback(predictions)
93
+ predicted_labels = extract_labels(predictions)
94
+
95
+ # extract feedback and labels from the golden labels
96
+ reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
97
+ reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
98
+
99
+ # compute HF metrics
100
+ sacrebleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
101
+ rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
102
+ meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
103
+ bert_score = bertscore.compute(
104
+ predictions=predicted_feedback,
105
+ references=reference_feedback,
106
+ lang='en',
107
+ rescale_with_baseline=True)
108
+
109
+ # use sklearn to compute accuracy and f1 score
110
+ reference_labels_np = np.array(reference_labels)
111
+ accuracy = accuracy_score(reference_labels_np, predicted_labels)
112
+ f1_weighted = f1_score(reference_labels_np, predicted_labels, average='weighted')
113
+ f1_macro = f1_score(
114
+ reference_labels_np,
115
+ predicted_labels,
116
+ average='macro',
117
+ labels=['Incorrect', 'Partially correct', 'Correct'])
118
+
119
+ results = {
120
+ 'sacrebleu': sacrebleu_score,
121
+ 'rouge': rouge_score,
122
+ 'meteor': meteor_score,
123
+ 'bert_score': np.array(bert_score['f1']).mean().item(),
124
+ 'accuracy': accuracy,
125
+ 'f1_weighted': f1_weighted,
126
+ 'f1_macro': f1_macro
127
+ }
128
+
129
+ return results
130
+
131
+ def evaluate(model, tokenizer, dataloader):
132
+ """
133
+ Evaluate model on the given dataset
134
+
135
+ Params:
136
+ model (PreTrainedModel): seq2seq model
137
+ tokenizer (PreTrainedTokenizer): tokenizer from HuggingFace
138
+ dataloader (torch Dataloader): dataloader of the dataset to be used for evaluation
139
+ Returns:
140
+ results (dict): dictionary with the computed evaluation metrics
141
+ predictions (list): list of the decoded predictions of the model
142
+ """
143
+ decoded_preds, decoded_labels = [], []
144
+
145
+ model.eval()
146
+ # iterate through batchs in the dataloader
147
+ for batch in tqdm(dataloader):
148
+ with torch.no_grad():
149
+ batch = {k: v.to(device) for k, v in batch.items()}
150
+ # generate tokens from batch
151
+ generated_tokens = model.generate(
152
+ batch['input_ids'],
153
+ attention_mask=batch['attention_mask'],
154
+ max_length=MAX_TARGET_LENGTH
155
+ )
156
+ # get golden labels from batch
157
+ labels_batch = batch['labels']
158
+
159
+ # decode model predictions and golden labels
160
+ decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
161
+ decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
162
+
163
+ decoded_preds.append(decoded_preds_batch)
164
+ decoded_labels.append(decoded_labels_batch)
165
+
166
+ # convert predictions and golden labels into flattened lists
167
+ predictions = flatten_list(decoded_preds)
168
+ labels = flatten_list(decoded_labels)
169
+
170
+ # compute metrics based on predictions and golden labels
171
+ results = compute_metrics(predictions, labels)
172
+
173
+ return results, predictions