small fix with torch.finfo
modeling_lsg_bert.py  CHANGED  (+88 -193)
@@ -53,10 +53,11 @@ class LSGBertConfig(BertConfig):
         self.sparse_block_size = sparse_block_size
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
-
+
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None

         if self.sparsity_type in ["stride", "block_stride"]:
@@ -64,7 +65,7 @@ class LSGBertConfig(BertConfig):
                 logger.warning(
                     "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
                 )
-
+
         if self.num_global_tokens < 1:
             logger.warning(
                 "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
@@ -72,13 +73,23 @@ class LSGBertConfig(BertConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512

         if self.sparsity_factor > 0:
             assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
             assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
+
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")


 class BaseSelfAttention(nn.Module):
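
A quick illustration of the block_size / sparsity_factor constraint asserted above (plain Python, illustrative numbers, not code from this file):

block_size, sparsity_factor = 128, 4
assert block_size % sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
print(block_size // sparsity_factor)   # roughly: each block of 128 keys can be reduced to 32 sparse tokens
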
@@ -188,7 +199,7 @@ class CausalAttentionProduct(nn.Module):
                 diagonal=-1
                 )
             causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
+            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]

         del attention_mask

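
For reference, a minimal self-contained sketch of the additive-mask idiom used above: positions to hide get torch.finfo(dtype).min added to their attention scores, so softmax drives their weights to roughly zero (shapes and values below are illustrative only, not code from this repo):

import torch

scores = torch.randn(1, 1, 4, 4)                      # (batch, heads, queries, keys)
neg_inf = torch.finfo(scores.dtype).min               # dtype-aware "minus infinity"

# keys strictly after the query position get the large negative offset
causal_mask = torch.triu(torch.ones(4, 4, dtype=scores.dtype), diagonal=1) * neg_inf

probs = torch.softmax(scores + causal_mask, dim=-1)
print(probs[0, 0])                                    # row i attends only to keys <= i
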
@@ -529,7 +540,8 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -594,7 +606,8 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min

         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

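
The two blocks above turn a {0, 1} keep-mask into that additive form; a small sketch of the conversion (illustrative values only):

import torch

keep = torch.tensor([[1., 1., 0., 1.]])          # 1 = real token, 0 = empty/padded bucket
mask = (1. - keep.clamp(0, 1))                   # 1 where the position must be dropped
mask *= torch.finfo(mask.dtype).min              # 0 -> keep, finfo.min -> drop

scores = torch.zeros(1, 4)
print(torch.softmax(scores + mask, dim=-1))      # the third position gets ~0 attention weight
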
@@ -695,8 +708,6 @@ class LSGSelfAttention(BaseSelfAttention):
             output_attentions=output_attentions
         )

-        #if head_mask is not None:
-        #    outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
         return outputs

     def causal_forward(
@@ -862,12 +873,6 @@ class LSGSelfAttention(BaseSelfAttention):
         return x.reshape(n, h, -1, chunk_size, d)


-class LSGBertSelfOutput(BertSelfOutput):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
 class LSGAttention(BertAttention):

     def __init__(self, config):
@@ -875,107 +880,97 @@ class LSGAttention(BertAttention):
         nn.Module.__init__(self)

         self.self = LSGSelfAttention(config)
-        self.output = LSGBertSelfOutput(config)
+        self.output = BertSelfOutput(config)
         self.pruned_heads = set()


-class LSGBertIntermediate(BertIntermediate):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGBertOutput(BertOutput):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
 class LSGBertLayer(BertLayer):

     def __init__(self, config):

-        nn.Module.__init__(self)
-
-        self.chunk_size_feed_forward = config.chunk_size_feed_forward
-        self.seq_len_dim = 1
+        super().__init__(config)
+
         self.attention = LSGAttention(config)
-        self.is_decoder = config.is_decoder
-        self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             if not self.is_decoder:
                 assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
             self.crossattention = LSGAttention(config)
-        self.intermediate = LSGBertIntermediate(config)
-        self.output = LSGBertOutput(config)


 class LSGBertEncoder(BertEncoder):

     def __init__(self, config):

-        nn.Module.__init__(self)
-
-        self.config = config
-        self.layer = nn.ModuleList([LSGBertLayer(config) for _ in range(config.num_hidden_layers)])
-        self.gradient_checkpointing = False
-
-
-class LSGBertPooler(BertPooler):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGBertPredictionHeadTransform(BertPredictionHeadTransform):
-
-    def __init__(self, config):
         super().__init__(config)
+
+        self.layer = nn.ModuleList([LSGBertLayer(config) for _ in range(config.num_hidden_layers)])
+
+        assert hasattr(config, "num_global_tokens")
+        self.num_global_tokens = config.num_global_tokens
+        self.pad_idx = config.pad_token_id
+
+        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
+        self.block_size = config.block_size
+        self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
+        self.pool_with_global = config.pool_with_global
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+
+        mask_value = torch.finfo(attention_mask.dtype).min
+        n, _, __, t = attention_mask.size()
+
+        if not (self.config.is_decoder and encoder_hidden_states is not None):
+
+            b = self.block_size * 2
+            pad = t % self.block_size
+
+            # Check if t is multiple of block_size and pad
+            if self.adaptive and t > b and pad > 0:
+                pad_length = self.block_size - pad
+                hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
+
+            if self.mask_first_token:
+                attention_mask[..., 0] = mask_value
+
+        encoder_outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        sequence_output = encoder_outputs[0]
+        if self.pool_with_global:
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
+
+        # Adapt sequence to initial shape
+        sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
+
+        if not return_dict:
+            return (sequence_output, ) + encoder_outputs[1:]
+
+        encoder_outputs.last_hidden_state = sequence_output
+        return encoder_outputs

-
-class LSGBertLMPredictionHead(BertLMPredictionHead):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.transform = LSGBertPredictionHeadTransform(config)
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-
-class LSGBertOnlyMLMHead(BertOnlyMLMHead):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.predictions = LSGBertLMPredictionHead(config)
-
-
-class LSGBertOnlyNSPHead(BertOnlyNSPHead):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGBertPreTrainingHeads(BertPreTrainingHeads):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.predictions = BertLMPredictionHead(config)
-        self.seq_relationship = nn.Linear(config.hidden_size, 2)


 class LSGBertPreTrainedModel(BertPreTrainedModel):
     """
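
The padding logic added to LSGBertEncoder.forward can be illustrated in isolation; the block size, the 4-D additive-mask layout and mask_value below mirror the diff, but the concrete numbers are made up:

import torch

block_size, t, d = 4, 10, 8
hidden_states = torch.randn(2, t, d)
attention_mask = torch.zeros(2, 1, 1, t)                  # additive mask, 0 = visible
mask_value = torch.finfo(attention_mask.dtype).min

pad = t % block_size
if pad > 0:
    pad_length = block_size - pad
    # pad the sequence dimension of the hidden states with zeros ...
    hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
    # ... and hide those padded positions in the additive mask
    attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)

print(hidden_states.shape, attention_mask.shape)          # both now end on a multiple of block_size
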
@@ -1003,19 +998,10 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         LSGBertPreTrainedModel.__init__(self, config)

         self.config = config
-        assert hasattr(config, "num_global_tokens")
-        self.num_global_tokens = config.num_global_tokens
-        self.pad_idx = config.pad_token_id
-
-        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
-        self.block_size = config.block_size
-        self.adaptive = config.adaptive
-        self.mask_first_token = config.mask_first_token
-        self.pool_with_global = config.pool_with_global

         self.embeddings = LSGBertEmbeddings(config)
         self.encoder = LSGBertEncoder(config)
-        self.pooler = LSGBertPooler(config) if add_pooling_layer else None
+        self.pooler = BertPooler(config) if add_pooling_layer else None

         if config.add_cross_attention:
             logger.warning(
@@ -1025,97 +1011,6 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None
-        ):
-
-        inputs_ = input_ids if input_ids is not None else inputs_embeds
-        n, t = inputs_.size()[:2]
-
-        if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
-        if self.mask_first_token:
-            attention_mask[:,0] = 0
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(n, t, device=inputs_.device).long()
-
-        b = self.block_size * 2
-        pad = t % self.block_size
-
-        # Check if t is multiple of block_size and pad
-        if self.adaptive and t > b and pad > 0:
-            pad_length = self.block_size - pad
-            if input_ids is not None:
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
-            else:
-                inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
-
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
-
-            if position_ids is not None:
-                position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
-
-        n, t_ = attention_mask.size()
-
-        encoder_outputs = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        context = encoder_outputs[0]
-        if self.pool_with_global:
-            context[:, self.num_global_tokens] = context[:, 0]
-
-        diff = t - t_
-        n, _, d = context.size()
-        context = context[..., self.num_global_tokens:, :]
-
-        # Adapt sequence to initial shape
-        if diff < 0:
-            context = context[:, :t]
-
-        encoder_outputs.last_hidden_state = context
-
-        sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
-        )
-
     def get_extended_attention_mask(self, attention_mask, input_shape, device=None):

         # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
@@ -1134,33 +1029,33 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         return extended_attention_mask


-class LSGBertForPreTraining(LSGBertPreTrainedModel):
+class LSGBertForPreTraining(LSGBertPreTrainedModel, BertForPreTraining):

     def __init__(self, config):

-        super().__init__(config)
+        LSGBertPreTrainedModel.__init__(self, config)

         self.bert = LSGBertModel(config)
-        self.cls = LSGBertPreTrainingHeads(config)
+        self.cls = BertPreTrainingHeads(config)

         # Initialize weights and apply final processing
         self.post_init()


-class LSGBertLMHeadModel(BertLMHeadModel):
+class LSGBertLMHeadModel(LSGBertPreTrainedModel, BertLMHeadModel):

     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):

-        super().__init__(config)
+        LSGBertPreTrainedModel.__init__(self, config)

         if not config.is_decoder:
             logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")

         self.bert = LSGBertModel(config, add_pooling_layer=False)
-        self.cls = LSGBertOnlyMLMHead(config)
+        self.cls = BertOnlyMLMHead(config)

         # Initialize weights and apply final processing
         self.post_init()
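
The stock heads now used in place of the removed LSG wrappers can be instantiated directly from transformers (module path as in recent transformers releases); a small sanity check with an illustrative config:

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertOnlyNSPHead, BertPreTrainingHeads

config = BertConfig(hidden_size=64, vocab_size=1000, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
print(BertOnlyMLMHead(config).predictions.decoder.out_features)   # vocab_size
print(BertOnlyNSPHead(config).seq_relationship.out_features)      # 2 (is-next / not-next)
print(type(BertPreTrainingHeads(config).predictions).__name__)    # BertLMPredictionHead
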
@@ -1187,7 +1082,7 @@ class LSGBertForMaskedLM(LSGBertPreTrainedModel, BertForMaskedLM):
         )

         self.bert = LSGBertModel(config, add_pooling_layer=False)
-        self.cls = LSGBertOnlyMLMHead(config)
+        self.cls = BertOnlyMLMHead(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1200,7 +1095,7 @@ class LSGBertForNextSentencePrediction(LSGBertPreTrainedModel, BertForNextSentencePrediction):
         LSGBertPreTrainedModel.__init__(self, config)

         self.bert = LSGBertModel(config)
-        self.cls = LSGBertOnlyNSPHead(config)
+        self.cls = BertOnlyNSPHead(config)

         # Initialize weights and apply final processing
         self.post_init()