Commit 24ea396
Parent: 6db9dbd
updated safetensors so they match

Files changed: modeling_lsg_bert.py (+80 -4)

modeling_lsg_bert.py (CHANGED)
@@ -23,7 +23,11 @@ from transformers.models.bert.modeling_bert import (
 )
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions
+)
+
 from transformers.models.bert.configuration_bert import BertConfig
 import math
 import sys
@@ -817,7 +821,7 @@ class LSGSelfAttention(BaseSelfAttention):
         n, h, t, d = query_layer.size()
 
         # Cat global mask
-        attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
+        # attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
 
         # Use normal attention if local attention covers every tokens
         if t <= 2 * self.block_size + self.num_global_tokens:
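
The left-padding commented out here is not dropped from the model: the same widening by num_global_tokens is now applied to the 2D attention mask inside the new LSGBertModel.forward (last hunk below), before the encoder is called. A minimal standalone sketch of that bookkeeping, using hypothetical sizes (batch 2, sequence length 4, one global token) that are not taken from this commit:

    import torch

    num_global_tokens = 1              # hypothetical value, for illustration only
    mask = torch.ones(2, 4)            # (batch_size, seq_len) attention mask of 1s/0s

    # Same bookkeeping as the new forward(): allocate a mask that is
    # num_global_tokens wider and copy the original mask after the global slots.
    wide = torch.zeros(mask.size(0), mask.size(1) + num_global_tokens, dtype=mask.dtype)
    wide[:, num_global_tokens:] = mask
    print(wide.shape)                  # torch.Size([2, 5]) == (batch, seq_len + num_global_tokens)
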
@@ -1023,9 +1027,9 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
     def __init__(self, config, add_pooling_layer=True):
 
-        #
+        # 1) Initialize the standard BertModel
         BertModel.__init__(self, config)
-
+        # 2) Initialize our LSG PreTrained
         LSGBertPreTrainedModel.__init__(self, config)
 
         self.config = config
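
The added comments label a double-initialization pattern: each base class's __init__ is called explicitly on the same instance instead of relying on a single cooperative super() chain. A generic toy sketch of that pattern, with made-up class names that are not from this file:

    class Base:
        def __init__(self, config):
            self.config = config

    class Standard(Base):
        def __init__(self, config):
            Base.__init__(self, config)
            self.standard_ready = True

    class Pretrained(Base):
        def __init__(self, config):
            Base.__init__(self, config)
            self.pretrained_ready = True

    class Combined(Pretrained, Standard):
        def __init__(self, config):
            # Call each parent explicitly, mirroring the order used in the diff
            Standard.__init__(self, config)
            Pretrained.__init__(self, config)

    m = Combined(config={"hidden_size": 8})
    print(m.standard_ready, m.pretrained_ready)    # True True
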
@@ -1041,6 +1045,78 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        # ----------------------------
+        # 1) Use LSG embeddings
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values[0][0].size(2) if past_key_values else 0
+        ) if (input_ids is not None or inputs_embeds is not None) else None
+
+        # 2) If we have an attention mask and some global tokens, pad the mask
+        # by `config.num_global_tokens` so it matches embedding_output.size(1).
+        if attention_mask is not None and self.config.num_global_tokens > 0:
+            # Original shape: (batch_size, seq_len)
+            bsz, seq_len = attention_mask.shape
+
+            new_shape = (bsz, seq_len + self.config.num_global_tokens)
+            extended_mask = torch.zeros(new_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+            # Fill from index `num_global_tokens` onward
+            extended_mask[:, self.config.num_global_tokens:] = attention_mask
+            attention_mask = extended_mask
+
+        # 3) Now call self.encoder with the updated mask
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=attention_mask.unsqueeze(1).unsqueeze(2) if attention_mask is not None else None,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        # 4) Grab the last hidden state
+        sequence_output = encoder_outputs[0]
+
+        # 5) Optionally apply the pooler
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        # Return
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
 
     def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
 
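
In the encoder call added above, .unsqueeze(1).unsqueeze(2) reshapes the widened 2D mask into a broadcastable 4D shape that can be applied against per-head attention scores. A short sketch with hypothetical sizes (not from this commit):

    import torch

    attention_mask = torch.ones(2, 5)                    # (batch_size, seq_len) after widening
    extended = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)
    print(extended.shape)                                # torch.Size([2, 1, 1, 5])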