Optimized wrapper with correct API
modeling_norbert.py (+64 -36)
CHANGED
@@ -1,12 +1,9 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
 import math
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import _softmax_backward_data as _softmax_backward_data
 from torch.utils import checkpoint
 
 from configuration_norbert import NorbertConfig

@@ -20,6 +17,7 @@ from transformers.modeling_outputs import (
     TokenClassifierOutput,
     BaseModelOutput
 )
+from transformers.pytorch_utils import softmax_backward_data
 
 
 class Encoder(nn.Module):

@@ -130,8 +128,8 @@ class MaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         output, = self.saved_tensors
-        …
-        return …
+        input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
+        return input_grad, None, None
 
 
 class Attention(nn.Module):

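Note: this commit swaps torch's private _softmax_backward_data for the softmax_backward_data helper in transformers.pytorch_utils, which adapts the call to the installed torch version (the private op's last argument changed from a tensor to a dtype around torch 1.11). Below is a minimal sketch of a masked softmax autograd function built on that helper; the forward's mask handling follows the DeBERTa-style pattern, and MaskedSoftmaxSketch is a hypothetical stand-in, not this file's class.

import torch
from transformers.pytorch_utils import softmax_backward_data


class MaskedSoftmaxSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask, dim):
        ctx.dim = dim
        # Masked positions get -inf before the softmax, then exactly zero probability after it.
        x = x.masked_fill(mask, float("-inf"))
        x = torch.softmax(x, ctx.dim)
        x = x.masked_fill(mask, 0.0)
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # The helper adjusts the private-op call to the installed torch version.
        input_grad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output)
        return input_grad, None, None


scores = torch.randn(2, 4, 8, 8, requires_grad=True)
padding = torch.zeros(2, 4, 8, 8, dtype=torch.bool)  # nothing masked in this toy run
probs = MaskedSoftmaxSketch.apply(scores, padding, -1)
probs.sum().backward()
print(scores.grad.shape)  # torch.Size([2, 4, 8, 8])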
@@ -188,31 +186,36 @@ class Attention(nn.Module):
         if self.position_indices.size(0) < query_len:
             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
-            position_indices = self.make_log_bucket_position(position_indices, self.…
-            position_indices = self.…
-            self.…
+            position_indices = self.make_log_bucket_position(position_indices, self.position_bucket_size, 512)
+            position_indices = self.position_bucket_size - 1 + position_indices
+            self.position_indices = position_indices.to(hidden_states.device)
 
         hidden_states = self.pre_layer_norm(hidden_states)
 
         query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
         value = self.in_proj_v(hidden_states)  # shape: [T, B, D]
 
-        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2T-1, 2D]
-        pos = F.embedding(self.position_indices[:query_len, :key_len], pos)  # shape: [T, T, 2D]
-        pos = pos.view(query_len, key_len, self.num_heads, 2*self.head_size)
-        query_pos, key_pos = pos.chunk(2, dim=3)
-
         query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
         key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
         value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
 
         attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
 
+        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2T-1, 2D]
+        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
         query = query.view(batch_size, self.num_heads, query_len, self.head_size)
         key = key.view(batch_size, self.num_heads, query_len, self.head_size)
+
+        attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
+        attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
+
+        position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
+        attention_c_p = attention_c_p.gather(3, position_indices)
+        attention_p_c = attention_p_c.gather(2, position_indices)
+
         attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
-        attention_scores.add_(…
-        attention_scores.add_(…
+        attention_scores.add_(attention_c_p)
+        attention_scores.add_(attention_p_c)
 
         return attention_scores, value

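Note: the rewritten attention above scores the content-to-position and position-to-content terms once per relative-position bucket with torch.einsum and then gathers the bucket each (query, key) pair needs, rather than materializing a per-pair [T, T, 2D] tensor through F.embedding as the removed lines did. A toy check of that gather trick, under made-up shapes and hypothetical names (not this module's variables):

import torch

B, H, T, D, num_buckets = 2, 3, 5, 4, 7
query = torch.randn(B, H, T, D)
key_pos = torch.randn(num_buckets, H, D)            # one position key per bucket and head
position_indices = torch.randint(num_buckets, (T, T))

# Score every bucket once: [B, H, T, num_buckets]
c2p_all = torch.einsum("bhqd,khd->bhqk", query, key_pos)
# Pick, for each (query, key) pair, the score of its bucket: [B, H, T, T]
c2p = c2p_all.gather(3, position_indices.expand(B, H, -1, -1))

# Reference: look the bucket up pair by pair.
reference = torch.zeros(B, H, T, T)
for q in range(T):
    for k in range(T):
        reference[:, :, q, k] = c2p_all[:, :, q, position_indices[q, k]]
assert torch.allclose(c2p, reference)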
@@ -332,12 +335,16 @@ class NorbertModel(NorbertPreTrainedModel):
         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
 
         if not return_dict:
-            return …
+            return (
+                sequence_output,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
 
         return BaseModelOutput(
             last_hidden_state=sequence_output,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )
 
 
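Note: the plain-tuple branch now follows the Hugging Face convention that optional members appear only when requested, so tuple positions shift with the flags. A tiny, self-contained illustration of the star-unpacking pattern repeated in every head below (dummy tensors, not real model outputs):

import torch

sequence_output = torch.zeros(1, 4, 8)
contextualized_embeddings = [torch.zeros(1, 4, 8)]   # stand-in for the per-layer list
attention_probs = [torch.zeros(1, 2, 4, 4)]          # stand-in for per-layer attention maps
output_hidden_states, output_attentions = True, False

output = (
    sequence_output,
    *([contextualized_embeddings] if output_hidden_states else []),
    *([attention_probs] if output_attentions else [])
)
assert len(output) == 2  # hidden states included, attentions left out, so positions shift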
@@ -375,14 +382,18 @@ class NorbertForMaskedLM(NorbertModel):
             masked_lm_loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())
 
         if not return_dict:
-            output = (…
+            output = (
+                subword_prediction,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
 
         return MaskedLMOutput(
             loss=masked_lm_loss,
             logits=subword_prediction,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )
 
 
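Note: the masked-LM loss keeps the flatten-and-cross-entropy pattern shown in the context line above. A sketch with dummy shapes; the -100 ignore index is F.cross_entropy's default and an assumption about how padded label positions are encoded, not something this diff shows:

import torch
import torch.nn.functional as F

T, B, V = 6, 2, 50                      # illustrative sequence length, batch size, vocab size
subword_prediction = torch.randn(T, B, V)
labels = torch.randint(V, (T, B))
labels[0, 0] = -100                     # e.g. a position that should not contribute to the loss
loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())
print(loss.item())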
@@ -465,14 +476,18 @@ class NorbertForSequenceClassification(NorbertModel):
             loss = loss_fct(logits, labels)
 
         if not return_dict:
-            output = (…
+            output = (
+                logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((loss,) + output) if loss is not None else output
 
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )
 
 
@@ -508,14 +523,18 @@ class NorbertForTokenClassification(NorbertModel):
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
-            output = (…
+            output = (
+                logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )
 
 
@@ -569,15 +588,20 @@ class NorbertForQuestionAnswering(NorbertModel):
             total_loss = (start_loss + end_loss) / 2
 
         if not return_dict:
-            output = …
+            output = (
+                start_logits,
+                end_logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((total_loss,) + output) if total_loss is not None else output
 
         return QuestionAnsweringModelOutput(
             loss=total_loss,
             start_logits=start_logits,
             end_logits=end_logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )
 
 
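Note: downstream of QuestionAnsweringModelOutput, callers usually turn start_logits and end_logits into a span prediction. A hedged sketch with dummy logits; a real decoder would also mask question and special tokens and cap the answer length:

import torch

seq_len = 16
start_logits = torch.randn(1, seq_len)
end_logits = torch.randn(1, seq_len)

# Score every (start, end) pair as start_logits[s] + end_logits[e], keep only spans with start <= end.
span_scores = start_logits.unsqueeze(-1) + end_logits.unsqueeze(-2)   # [1, seq_len, seq_len]
valid = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool))
span_scores = span_scores.masked_fill(~valid, float("-inf"))
best = span_scores.view(1, -1).argmax(dim=-1)
start, end = best // seq_len, best % seq_len
print(start.item(), end.item())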
@@ -598,9 +622,9 @@ class NorbertForMultipleChoice(NorbertModel):
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-        …
-        …
-        …
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None
     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         num_choices = input_ids.shape[1]

@@ -618,12 +642,16 @@ class NorbertForMultipleChoice(NorbertModel):
             loss = loss_fct(reshaped_logits, labels)
 
         if not return_dict:
-            output = (…
+            output = (
+                reshaped_logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((loss,) + output) if loss is not None else output
 
         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )
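Note: num_choices = input_ids.shape[1] implies the usual multiple-choice layout: inputs arrive as [batch, num_choices, seq_len], are flattened before the encoder, and the per-sequence scores come back as reshaped_logits of shape [batch, num_choices]. The flattening itself is outside this diff, so the sketch below is an assumption about the surrounding code, shown with dummy tensors:

import torch

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(100, (batch_size, num_choices, seq_len))
flat_input_ids = input_ids.view(-1, seq_len)              # [batch * num_choices, seq_len]
logits = torch.randn(flat_input_ids.size(0), 1)           # one score per (example, choice)
reshaped_logits = logits.view(batch_size, num_choices)    # compared against labels in [0, num_choices)
print(reshaped_logits.shape)  # torch.Size([2, 4])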