botcon committed
Commit 60cea80
1 Parent(s): 0033f23

Upload LukeQuestionAnswering.py

Files changed (1)
LukeQuestionAnswering.py  +340 -0
LukeQuestionAnswering.py ADDED
@@ -0,0 +1,340 @@
from transformers import LukePreTrainedModel, LukeModel, AutoTokenizer, TrainingArguments, default_data_collator, Trainer, LukeForQuestionAnswering
from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple, Union

import torch
from dataclasses import dataclass
from datasets import load_dataset
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

PEFT = True
repo_name = "LUKE_squad_finetuned_qa"

# https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/luke/modeling_luke.py#L319-L353
# Taken from the HF repository so that additional features are easier to include -- currently identical to HF's LukeForQuestionAnswering


@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
    """
    Outputs of question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class AugmentedLukeForQuestionAnswering(LukePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # This is 2 (start and end logits).
        self.num_labels = config.num_labels

        self.luke = LukeModel(config, add_pooling_layer=False)

        '''
        Any improvements to the model are expected here. Additional features, anything...
        '''
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.FloatTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            return tuple(
                v
                for v in [
                    total_loss,
                    start_logits,
                    end_logits,
                    outputs.hidden_states,
                    outputs.entity_hidden_states,
                    outputs.attentions,
                ]
                if v is not None
            )

        return LukeQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            entity_hidden_states=outputs.entity_hidden_states,
            attentions=outputs.attentions,
        )


if __name__ == "__main__":
    base_luke = "studio-ousia/luke-base"

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # LUKE does not ship a fast tokenizer.
    # Work-around: RoBERTa and LUKE share the same subword vocabulary, and we are not using the
    # entity-specific features of the LUKE tokenizer anyway, so the RoBERTa fast tokenizer will do.
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
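
    # The preprocessing below relies on fast-tokenizer features (offset mappings, overflow
    # handling, sequence_ids), so make sure we really did get a fast tokenizer.
    assert tokenizer.is_fast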

    # tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AugmentedLukeForQuestionAnswering.from_pretrained(base_luke).to(device)
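
    # Minimal sanity check (a sketch, not needed for training): one forward pass with the freshly
    # initialized QA head to confirm it returns per-token start/end logits of shape
    # (batch_size, sequence_length).
    with torch.no_grad():
        sample = tokenizer(
            "Who proposed LUKE?",
            "LUKE is an entity-aware transformer model proposed by Yamada et al.",
            return_tensors="pt",
        ).to(device)
        sample_output = model(**sample)
    print(sample_output.start_logits.shape, sample_output.end_logits.shape)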

    raw_datasets = load_dataset("squad")

    # Not exactly hyperparameters
    max_length = 384
    stride = 128
    batch_size = 3

    def preprocess_training_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = answers[sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = inputs.sequence_ids(i)

            # Find the start and end of the context
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the context, label is (0, 0)
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Otherwise it's the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    train_dataset = raw_datasets["train"].map(
        preprocess_training_examples,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
    )
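
    # Quick check (not required for training): long contexts are split into overlapping windows
    # (max_length tokens with a stride-token overlap), so there are more features than examples.
    print(f"{len(raw_datasets['train'])} train examples -> {len(train_dataset)} training features")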

    def preprocess_validation_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            sample_idx = sample_map[i]
            example_ids.append(examples["id"][sample_idx])

            sequence_ids = inputs.sequence_ids(i)
            offset = inputs["offset_mapping"][i]
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        inputs["example_id"] = example_ids
        return inputs

    validation_dataset = raw_datasets["validation"].map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=raw_datasets["validation"].column_names,
    )
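
    # Quick check (not required for training): each validation feature keeps a pointer back to its
    # source example, and offsets outside the context are masked to None so that post-processing
    # can skip question/special tokens.
    print(validation_dataset[0]["example_id"], validation_dataset[0]["offset_mapping"][:5])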

    # --------------- PEFT -------------------- #
    # One epoch without PEFT took about 2h on my machine with CUDA; the performance of the PEFT
    # model has been disappointing so far, though.
    if PEFT:
        from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

        # ---- Apply LoRA to all linear layers: pull their attribute names out of the model repr ----
        import re
        pattern = r'\((\w+)\): Linear'
        linear_layers = re.findall(pattern, str(model))
        target_modules = list(set(linear_layers))

        # If using PEFT, consider increasing r for better performance
        peft_config = LoraConfig(
            task_type=TaskType.QUESTION_ANS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=target_modules, bias='all'
        )

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        repo_name += "_PEFT"

    # ------------------------------------------ #

    args = TrainingArguments(
        repo_name,
        evaluation_strategy="no",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=3,
        weight_decay=0.01,
        push_to_hub=True,
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Not complete yet -- the answer post-processing / evaluation step is still missing; using the
    # HF Hub and the course recipe linked below to get results for now.
    # https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt
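
    # Sketch of that missing post-processing, adapted from the course chapter linked above.
    # `compute_squad_metrics` is a hypothetical helper (not called anywhere yet): it maps the
    # start/end logits from `trainer.predict(validation_dataset)` back to answer strings and
    # scores them with the SQuAD metric from the `evaluate` library.
    def compute_squad_metrics(start_logits, end_logits, features, examples,
                              n_best=20, max_answer_length=30):
        import collections

        import evaluate
        import numpy as np

        metric = evaluate.load("squad")

        # Group the features belonging to each original example
        example_to_features = collections.defaultdict(list)
        for idx, feature in enumerate(features):
            example_to_features[feature["example_id"]].append(idx)

        predicted_answers = []
        for example in examples:
            example_id = example["id"]
            context = example["context"]
            answers = []

            for feature_index in example_to_features[example_id]:
                start_logit = start_logits[feature_index]
                end_logit = end_logits[feature_index]
                offsets = features[feature_index]["offset_mapping"]

                start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
                end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Skip answers that fall outside the context or form an invalid span
                        if offsets[start_index] is None or offsets[end_index] is None:
                            continue
                        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                            continue
                        answers.append(
                            {
                                "text": context[offsets[start_index][0] : offsets[end_index][1]],
                                "logit_score": start_logit[start_index] + end_logit[end_index],
                            }
                        )

            if answers:
                best_answer = max(answers, key=lambda x: x["logit_score"])
                predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
            else:
                predicted_answers.append({"id": example_id, "prediction_text": ""})

        references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
        return metric.compute(predictions=predicted_answers, references=references)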