botcon committed on
Commit 8cf17a4
1 Parent(s): a6286de

Upload QuestionAnswering.py

Files changed (1)
  1. QuestionAnswering.py +476 -0
QuestionAnswering.py ADDED
@@ -0,0 +1,476 @@
from transformers import LukePreTrainedModel, LukeModel, AutoTokenizer, TrainingArguments, default_data_collator, Trainer, AutoModelForQuestionAnswering
from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple, Union

import numpy as np
from tqdm import tqdm
import evaluate
import torch
from dataclasses import dataclass
from datasets import load_dataset, concatenate_datasets
from torch import nn
from torch.nn import CrossEntropyLoss
import collections

PEFT = False
tf32 = True
fp16 = True
train = False
test = True
trained_model = "LUKE_squadshift"
train_checkpoint = None

base_tokenizer = "roberta-base"
base_model = "studio-ousia/luke-base"

# base_tokenizer = "xlnet-base-cased"
# base_model = "xlnet-base-cased"

# base_tokenizer = "bert-base-cased"
# base_model = "SpanBERT/spanbert-base-cased"

torch.backends.cuda.matmul.allow_tf32 = tf32
torch.backends.cudnn.allow_tf32 = tf32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
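# Note: the TF32 switches above only take effect on Ampere (or newer) NVIDIA GPUs;
# on other hardware they are no-ops.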

if tf32:
    trained_model += "_tf32"

# https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/luke/modeling_luke.py#L319-L353
# Adapted from the HF repository so that additional features are easier to add -- currently identical to HF's LukeForQuestionAnswering

@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
    """
    Outputs of question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class AugmentedLukeForQuestionAnswering(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # config.num_labels is 2 (start and end of the answer span).
        self.num_labels = config.num_labels

        self.luke = LukeModel(config, add_pooling_layer=False)

        '''
        Any improvements to the model are expected to go here: additional features, anything...
        '''
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            return tuple(
                v
                for v in [
                    total_loss,
                    start_logits,
                    end_logits,
                    outputs.hidden_states,
                    outputs.entity_hidden_states,
                    outputs.attentions,
                ]
                if v is not None
            )

        return LukeQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            entity_hidden_states=outputs.entity_hidden_states,
            attentions=outputs.attentions,
        )
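
# A minimal usage sketch (assumption: the training/evaluation flow below loads
# AutoModelForQuestionAnswering rather than this class). The augmented model can be
# initialised from the same pretrained LUKE checkpoint, with qa_outputs randomly initialised:
#   model = AugmentedLukeForQuestionAnswering.from_pretrained(base_model).to(device)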

# Get data to train the model - SquadShifts is designed as a validation/testing set,
# so each question has multiple reference answers; keep only the shortest one
def get_squadshifts_training():
    wiki = load_dataset("squadshifts", "new_wiki")["test"]
    nyt = load_dataset("squadshifts", "nyt")["test"]
    reddit = load_dataset("squadshifts", "reddit")["test"]
    raw_dataset = concatenate_datasets([wiki, nyt, reddit])
    updated = raw_dataset.map(validation_to_train)
    return updated

def validation_to_train(example):
    answers = example["answers"]
    answer_text = answers["text"]
    index_min = min(range(len(answer_text)), key=lambda i: len(answer_text[i]))
    answers["text"] = answers["text"][index_min:index_min + 1]
    answers["answer_start"] = answers["answer_start"][index_min:index_min + 1]
    return example
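
# Illustrative effect (hypothetical record): answers = {"text": ["in 1998", "1998"], "answer_start": [10, 13]}
# becomes {"text": ["1998"], "answer_start": [13]}, i.e. a single training label per example.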

if __name__ == "__main__":
    # Set up the tokenizer and helper functions
    # Work-around: LUKE has no fast tokenizer, but RoBERTa and LUKE share the same subword vocab,
    # and the entity-specific features of the LUKE tokenizer are not used here anyway
    tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)
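    # The offset mappings and sequence_ids() used below require a fast (Rust-backed) tokenizer,
    # which the RoBERTa checkpoint provides.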

    # Necessary initialization
    max_length = 500
    stride = 128
    batch_size = 8
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")
    raw_datasets = load_dataset("squad")

    raw_train = raw_datasets["train"]
    raw_validation = raw_datasets["validation"]

    def compute_metrics(start_logits, end_logits, features, examples):
        example_to_features = collections.defaultdict(list)
        for idx, feature in enumerate(features):
            example_to_features[feature["example_id"]].append(idx)

        predicted_answers = []
        for example in tqdm(examples):
            example_id = example["id"]
            context = example["context"]
            answers = []

            # Loop through all features associated with that example
            for feature_index in example_to_features[example_id]:
                start_logit = start_logits[feature_index]
                end_logit = end_logits[feature_index]
                offsets = features[feature_index]["offset_mapping"]

                start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
                end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Skip answers that are not fully in the context
                        if offsets[start_index] is None or offsets[end_index] is None:
                            continue
                        # Skip answers with a length that is either < 0 or > max_answer_length
                        if (
                            end_index < start_index
                            or end_index - start_index + 1 > max_answer_length
                        ):
                            continue

                        answer = {
                            "text": context[offsets[start_index][0] : offsets[end_index][1]],
                            "logit_score": start_logit[start_index] + end_logit[end_index],
                        }
                        answers.append(answer)

            # Select the answer with the best score
            if len(answers) > 0:
                best_answer = max(answers, key=lambda x: x["logit_score"])
                predicted_answers.append(
                    {"id": example_id, "prediction_text": best_answer["text"]}
                )
            else:
                predicted_answers.append({"id": example_id, "prediction_text": ""})

        theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
        return metric.compute(predictions=predicted_answers, references=theoretical_answers)
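
    # metric.compute returns a dict of the form {"exact_match": ..., "f1": ...} on a 0-100 scale,
    # averaged over the examples passed in.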

    def preprocess_training_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = answers[sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = inputs.sequence_ids(i)

            # Find the start and end of the context
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the context, label is (0, 0)
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Otherwise it's the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs
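
    # With max_length=500 and stride=128, a long context is split into several overlapping
    # features; each feature gets its own (start, end) label, which is (0, 0) whenever the
    # gold answer falls outside that particular window.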

    def preprocess_validation_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            sample_idx = sample_map[i]
            example_ids.append(examples["id"][sample_idx])

            sequence_ids = inputs.sequence_ids(i)
            offset = inputs["offset_mapping"][i]
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        inputs["example_id"] = example_ids
        return inputs
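
    # Offsets belonging to the question and special tokens are set to None so that
    # compute_metrics can skip candidate spans that fall outside the context.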

    if train:
        model = AutoModelForQuestionAnswering.from_pretrained(base_model).to(device)

        # For SquadShifts fine-tuning
        raw_train = get_squadshifts_training()

        train_dataset = raw_train.map(
            preprocess_training_examples,
            batched=True,
            remove_columns=raw_train.column_names,
        )

        validation_dataset = raw_validation.map(
            preprocess_validation_examples,
            batched=True,
            remove_columns=raw_validation.column_names,
        )

        # --------------- PEFT -------------------- #
        # One epoch without PEFT took about 2h on my machine with CUDA; PEFT quality was noticeably worse, though
        if PEFT:
            from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

            # ---- Target all linear layers ----
            import re
            pattern = r'\((\w+)\): Linear'
            linear_layers = re.findall(pattern, str(model.modules))
            target_modules = list(set(linear_layers))
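            # str(model.modules) includes the full module repr, so the regex collects the
            # attribute name of every nn.Linear submodule (qa_outputs included).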

            # If using PEFT, consider increasing r for better performance
            peft_config = LoraConfig(
                task_type=TaskType.QUESTION_ANS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=target_modules, bias='all'
            )

            model = get_peft_model(model, peft_config)
            model.print_trainable_parameters()

            trained_model += "_PEFT"

        # ------------------------------------------ #

        args = TrainingArguments(
            trained_model,
            evaluation_strategy="no",
            save_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=3,
            weight_decay=0.01,
            push_to_hub=True,
            fp16=fp16,
        )

        trainer = Trainer(
            model,
            args,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
            data_collator=default_data_collator,
            tokenizer=tokenizer,
        )

        trainer.train(train_checkpoint)
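
        # The positional argument to trainer.train() is resume_from_checkpoint: None starts a
        # fresh run, while a checkpoint path resumes training. push_to_hub=True assumes a
        # logged-in Hugging Face account (e.g. via `huggingface-cli login`).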

    if test:
        model = AutoModelForQuestionAnswering.from_pretrained(trained_model).to(device)

        interval = len(raw_datasets["validation"]) // 100
        exact_match = 0
        f1 = 0

        with torch.no_grad():
            for i in range(1, 101):
                start = interval * (i - 1)
                end = interval * i
                small_eval_set = raw_datasets["validation"].select(range(start, end))
                eval_set = small_eval_set.map(
                    preprocess_validation_examples,
                    batched=True,
                    remove_columns=raw_datasets["validation"].column_names
                )
                eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
                eval_set_for_model.set_format("torch")
                batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
                outputs = model(**batch)
                start_logits = outputs.start_logits.cpu().numpy()
                end_logits = outputs.end_logits.cpu().numpy()
                res = compute_metrics(start_logits, end_logits, eval_set, small_eval_set)
                exact_match += res["exact_match"]
                f1 += res["f1"]

        print("F1 score: {}".format(f1 / 100))
        print("Exact match: {}".format(exact_match / 100))
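
        # Each of the 100 chunks holds `interval` examples, so averaging the per-chunk scores
        # matches the metric over all evaluated examples; any remainder beyond interval * 100
        # is skipped.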

# XLNet
# F1 score: 91.54154256653278
# Exact match: 84.86666666666666

# SpanBERT
# F1 score: 92.160285362531
# Exact match: 85.73333333333333

# LUKE_squadshift (SQuAD then SquadShifts)
# F1 score: 91.27683543983473
# Exact match: 84.96190476190473