ccdv committed on
Commit
a68bf66
·
1 Parent(s): e2bad81

small fix with torch.finfo

Browse files
Files changed (1) hide show
  1. modeling_lsg_bert.py +88 -193
modeling_lsg_bert.py CHANGED
@@ -53,10 +53,11 @@ class LSGBertConfig(BertConfig):
53
  self.sparse_block_size = sparse_block_size
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
-
57
  if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
58
  logger.warning(
59
- "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
 
60
  self.sparsity_type = None
61
 
62
  if self.sparsity_type in ["stride", "block_stride"]:
@@ -64,7 +65,7 @@ class LSGBertConfig(BertConfig):
64
  logger.warning(
65
  "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
66
  )
67
-
68
  if self.num_global_tokens < 1:
69
  logger.warning(
70
  "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
@@ -72,13 +73,23 @@ class LSGBertConfig(BertConfig):
72
  self.num_global_tokens = 1
73
  elif self.num_global_tokens > 512:
74
  logger.warning(
75
- "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
76
  )
77
  self.num_global_tokens = 512
78
 
79
  if self.sparsity_factor > 0:
80
  assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
81
  assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  class BaseSelfAttention(nn.Module):
@@ -188,7 +199,7 @@ class CausalAttentionProduct(nn.Module):
188
  diagonal=-1
189
  )
190
  causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
191
- attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
192
 
193
  del attention_mask
194
 
@@ -529,7 +540,8 @@ class LSGSelfAttention(BaseSelfAttention):
529
  keys = keys.sum(dim=-2) / (mask + 1e-6)
530
  values = values.sum(dim=-2) / (mask + 1e-6)
531
 
532
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
533
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
534
 
535
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -594,7 +606,8 @@ class LSGSelfAttention(BaseSelfAttention):
594
  keys /= mask + 1e-8
595
  values /= mask + 1e-8
596
 
597
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
598
 
599
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
600
 
@@ -695,8 +708,6 @@ class LSGSelfAttention(BaseSelfAttention):
695
  output_attentions=output_attentions
696
  )
697
 
698
- #if head_mask is not None:
699
- # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
700
  return outputs
701
 
702
  def causal_forward(
@@ -862,12 +873,6 @@ class LSGSelfAttention(BaseSelfAttention):
862
  return x.reshape(n, h, -1, chunk_size, d)
863
 
864
 
865
- class LSGBertSelfOutput(BertSelfOutput):
866
-
867
- def __init__(self, config):
868
- super().__init__(config)
869
-
870
-
871
  class LSGAttention(BertAttention):
872
 
873
  def __init__(self, config):
@@ -875,107 +880,97 @@ class LSGAttention(BertAttention):
875
  nn.Module.__init__(self)
876
 
877
  self.self = LSGSelfAttention(config)
878
- self.output = LSGBertSelfOutput(config)
879
  self.pruned_heads = set()
880
 
881
 
882
- class LSGBertIntermediate(BertIntermediate):
883
-
884
- def __init__(self, config):
885
- super().__init__(config)
886
-
887
-
888
- class LSGBertOutput(BertOutput):
889
-
890
- def __init__(self, config):
891
- super().__init__(config)
892
-
893
-
894
  class LSGBertLayer(BertLayer):
895
 
896
  def __init__(self, config):
897
 
898
- nn.Module.__init__(self)
899
 
900
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
901
- self.seq_len_dim = 1
902
  self.attention = LSGAttention(config)
903
- self.is_decoder = config.is_decoder
904
- self.add_cross_attention = config.add_cross_attention
905
  if self.add_cross_attention:
906
  if not self.is_decoder:
907
  assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
908
  self.crossattention = LSGAttention(config)
909
- self.intermediate = LSGBertIntermediate(config)
910
- self.output = LSGBertOutput(config)
911
 
912
 
913
  class LSGBertEncoder(BertEncoder):
914
 
915
  def __init__(self, config):
916
 
917
- nn.Module.__init__(self)
918
-
919
- self.config = config
920
- self.layer = nn.ModuleList([LSGBertLayer(config) for _ in range(config.num_hidden_layers)])
921
- self.gradient_checkpointing = False
922
-
923
-
924
- class LSGBertPooler(BertPooler):
925
-
926
- def __init__(self, config):
927
- super().__init__(config)
928
-
929
-
930
- class LSGBertPredictionHeadTransform(BertPredictionHeadTransform):
931
-
932
- def __init__(self, config):
933
  super().__init__(config)
934
 
 
935
 
936
- class LSGBertLMPredictionHead(BertLMPredictionHead):
 
 
937
 
938
- def __init__(self, config):
 
 
 
 
939
 
940
- nn.Module.__init__(self)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
 
942
- self.transform = LSGBertPredictionHeadTransform(config)
943
-
944
- # The output weights are the same as the input embeddings, but there is
945
- # an output-only bias for each token.
946
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
947
-
948
- self.bias = nn.Parameter(torch.zeros(config.vocab_size))
949
-
950
- # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
951
- self.decoder.bias = self.bias
952
 
 
 
 
 
 
 
 
 
953
 
954
- class LSGBertOnlyMLMHead(BertOnlyMLMHead):
955
- """LSG Head for masked language modeling."""
956
 
957
- def __init__(self, config):
 
 
 
 
 
 
 
 
 
 
 
958
 
959
- nn.Module.__init__(self)
960
-
961
- self.predictions = LSGBertLMPredictionHead(config)
962
-
963
-
964
- class LSGBertOnlyNSPHead(BertOnlyNSPHead):
965
-
966
- def __init__(self, config):
967
- super().__init__(config)
968
-
969
 
970
- class LSGBertPreTrainingHeads(BertPreTrainingHeads):
 
971
 
972
- def __init__(self, config):
 
973
 
974
- nn.Module.__init__(self)
975
-
976
- self.predictions = BertLMPredictionHead(config)
977
- self.seq_relationship = nn.Linear(config.hidden_size, 2)
978
-
979
 
980
  class LSGBertPreTrainedModel(BertPreTrainedModel):
981
  """
@@ -1003,19 +998,10 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
1003
  LSGBertPreTrainedModel.__init__(self, config)
1004
 
1005
  self.config = config
1006
- assert hasattr(config, "num_global_tokens")
1007
- self.num_global_tokens = config.num_global_tokens
1008
- self.pad_idx = config.pad_token_id
1009
-
1010
- assert hasattr(config, "block_size") and hasattr(config, "adaptive")
1011
- self.block_size = config.block_size
1012
- self.adaptive = config.adaptive
1013
- self.mask_first_token = config.mask_first_token
1014
- self.pool_with_global = config.pool_with_global
1015
 
1016
  self.embeddings = LSGBertEmbeddings(config)
1017
  self.encoder = LSGBertEncoder(config)
1018
- self.pooler = LSGBertPooler(config) if add_pooling_layer else None
1019
 
1020
  if config.add_cross_attention:
1021
  logger.warning(
@@ -1025,97 +1011,6 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
1025
  # Initialize weights and apply final processing
1026
  self.post_init()
1027
 
1028
- def forward(
1029
- self,
1030
- input_ids=None,
1031
- attention_mask=None,
1032
- token_type_ids=None,
1033
- position_ids=None,
1034
- head_mask=None,
1035
- inputs_embeds=None,
1036
- encoder_hidden_states=None,
1037
- encoder_attention_mask=None,
1038
- past_key_values=None,
1039
- use_cache=None,
1040
- output_attentions=None,
1041
- output_hidden_states=None,
1042
- return_dict=None
1043
- ):
1044
-
1045
- inputs_ = input_ids if input_ids is not None else inputs_embeds
1046
- n, t = inputs_.size()[:2]
1047
-
1048
- if attention_mask is None:
1049
- attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
1050
- if self.mask_first_token:
1051
- attention_mask[:,0] = 0
1052
- if token_type_ids is None:
1053
- token_type_ids = torch.zeros(n, t, device=inputs_.device).long()
1054
-
1055
- b = self.block_size * 2
1056
- pad = t % self.block_size
1057
-
1058
- # Check if t is multiple of block_size and pad
1059
- if self.adaptive and t > b and pad > 0:
1060
- pad_length = self.block_size - pad
1061
- if input_ids is not None:
1062
- input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
1063
- else:
1064
- inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
1065
-
1066
- attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
1067
- token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
1068
-
1069
- if position_ids is not None:
1070
- position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
1071
-
1072
- n, t_ = attention_mask.size()
1073
-
1074
- encoder_outputs = super().forward(
1075
- input_ids=input_ids,
1076
- attention_mask=attention_mask,
1077
- token_type_ids=token_type_ids,
1078
- position_ids=position_ids,
1079
- head_mask=head_mask,
1080
- inputs_embeds=inputs_embeds,
1081
- encoder_hidden_states=encoder_hidden_states,
1082
- encoder_attention_mask=encoder_attention_mask,
1083
- past_key_values=past_key_values,
1084
- use_cache=use_cache,
1085
- output_attentions=output_attentions,
1086
- output_hidden_states=output_hidden_states,
1087
- return_dict=return_dict
1088
- )
1089
-
1090
- context = encoder_outputs[0]
1091
- if self.pool_with_global:
1092
- context[:, self.num_global_tokens] = context[:, 0]
1093
-
1094
- diff = t - t_
1095
- n, _, d = context.size()
1096
- context = context[..., self.num_global_tokens:, :]
1097
-
1098
- # Adapt sequence to initial shape
1099
- if diff < 0:
1100
- context = context[:, :t]
1101
-
1102
- encoder_outputs.last_hidden_state = context
1103
-
1104
- sequence_output = encoder_outputs[0]
1105
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1106
-
1107
- if not return_dict:
1108
- return (sequence_output, pooled_output) + encoder_outputs[1:]
1109
-
1110
- return BaseModelOutputWithPoolingAndCrossAttentions(
1111
- last_hidden_state=sequence_output,
1112
- pooler_output=pooled_output,
1113
- past_key_values=encoder_outputs.past_key_values,
1114
- hidden_states=encoder_outputs.hidden_states,
1115
- attentions=encoder_outputs.attentions,
1116
- cross_attentions=encoder_outputs.cross_attentions,
1117
- )
1118
-
1119
  def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1120
 
1121
  # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
@@ -1134,33 +1029,33 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
1134
  return extended_attention_mask
1135
 
1136
 
1137
- class LSGBertForPreTraining(LSGBertPreTrainedModel):
1138
 
1139
  def __init__(self, config):
1140
 
1141
- super().__init__(config)
1142
 
1143
  self.bert = LSGBertModel(config)
1144
- self.cls = LSGBertPreTrainingHeads(config)
1145
 
1146
  # Initialize weights and apply final processing
1147
  self.post_init()
1148
 
1149
 
1150
- class LSGBertLMHeadModel(BertLMHeadModel):
1151
 
1152
  _keys_to_ignore_on_load_unexpected = [r"pooler"]
1153
  _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1154
 
1155
  def __init__(self, config):
1156
 
1157
- BertPreTrainedModel.__init__(self, config)
1158
 
1159
  if not config.is_decoder:
1160
  logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1161
 
1162
  self.bert = LSGBertModel(config, add_pooling_layer=False)
1163
- self.cls = LSGBertOnlyMLMHead(config)
1164
 
1165
  # Initialize weights and apply final processing
1166
  self.post_init()
@@ -1187,7 +1082,7 @@ class LSGBertForMaskedLM(LSGBertPreTrainedModel, BertForMaskedLM):
1187
  )
1188
 
1189
  self.bert = LSGBertModel(config, add_pooling_layer=False)
1190
- self.cls = LSGBertOnlyMLMHead(config)
1191
 
1192
  # Initialize weights and apply final processing
1193
  self.post_init()
@@ -1200,7 +1095,7 @@ class LSGBertForNextSentencePrediction(LSGBertPreTrainedModel, BertForNextSenten
1200
  LSGBertPreTrainedModel.__init__(self, config)
1201
 
1202
  self.bert = LSGBertModel(config)
1203
- self.cls = LSGBertOnlyNSPHead(config)
1204
 
1205
  # Initialize weights and apply final processing
1206
  self.post_init()
 
53
  self.sparse_block_size = sparse_block_size
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
+
57
  if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
58
  logger.warning(
59
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
60
+ setting sparsity_type=None, computation will skip sparse attention")
61
  self.sparsity_type = None
62
 
63
  if self.sparsity_type in ["stride", "block_stride"]:
 
65
  logger.warning(
66
  "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
67
  )
68
+
69
  if self.num_global_tokens < 1:
70
  logger.warning(
71
  "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
 
73
  self.num_global_tokens = 1
74
  elif self.num_global_tokens > 512:
75
  logger.warning(
76
+ "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
77
  )
78
  self.num_global_tokens = 512
79
 
80
  if self.sparsity_factor > 0:
81
  assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
82
  assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
83
+
84
+ if self.mask_first_token and not pool_with_global:
85
+ logger.warning(
86
+ "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
87
+ self.pool_with_global = True
88
+
89
+ if hasattr(self, "position_embedding_type"):
90
+ if self.position_embedding_type != "absolute":
91
+ logger.warning(
92
+ "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
93
 
94
 
95
  class BaseSelfAttention(nn.Module):
 
199
  diagonal=-1
200
  )
201
  causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
202
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
203
 
204
  del attention_mask
205
 
 
540
  keys = keys.sum(dim=-2) / (mask + 1e-6)
541
  values = values.sum(dim=-2) / (mask + 1e-6)
542
 
543
+ mask = (1. - mask.clamp(0, 1))
544
+ mask *= torch.finfo(mask.dtype).min
545
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
546
 
547
  def get_sparse_tokens_with_stride(self, keys, values, mask):
 
606
  keys /= mask + 1e-8
607
  values /= mask + 1e-8
608
 
609
+ mask = (1. - mask.clamp(0, 1))
610
+ mask *= torch.finfo(mask.dtype).min
611
 
612
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
613
 
 
708
  output_attentions=output_attentions
709
  )
710
 
 
 
711
  return outputs
712
 
713
  def causal_forward(
 
873
  return x.reshape(n, h, -1, chunk_size, d)
874
 
875
 
 
 
 
 
 
 
876
  class LSGAttention(BertAttention):
877
 
878
  def __init__(self, config):
 
880
  nn.Module.__init__(self)
881
 
882
  self.self = LSGSelfAttention(config)
883
+ self.output = BertSelfOutput(config)
884
  self.pruned_heads = set()
885
 
886
 
 
 
 
 
 
 
 
 
 
 
 
 
887
  class LSGBertLayer(BertLayer):
888
 
889
  def __init__(self, config):
890
 
891
+ super().__init__(config)
892
 
 
 
893
  self.attention = LSGAttention(config)
 
 
894
  if self.add_cross_attention:
895
  if not self.is_decoder:
896
  assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
897
  self.crossattention = LSGAttention(config)
 
 
898
 
899
 
900
  class LSGBertEncoder(BertEncoder):
901
 
902
  def __init__(self, config):
903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904
  super().__init__(config)
905
 
906
+ self.layer = nn.ModuleList([LSGBertLayer(config) for _ in range(config.num_hidden_layers)])
907
 
908
+ assert hasattr(config, "num_global_tokens")
909
+ self.num_global_tokens = config.num_global_tokens
910
+ self.pad_idx = config.pad_token_id
911
 
912
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
913
+ self.block_size = config.block_size
914
+ self.adaptive = config.adaptive
915
+ self.mask_first_token = config.mask_first_token
916
+ self.pool_with_global = config.pool_with_global
917
 
918
+ def forward(
919
+ self,
920
+ hidden_states: torch.Tensor,
921
+ attention_mask: Optional[torch.FloatTensor] = None,
922
+ head_mask: Optional[torch.FloatTensor] = None,
923
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
924
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
925
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
926
+ use_cache: Optional[bool] = None,
927
+ output_attentions: Optional[bool] = False,
928
+ output_hidden_states: Optional[bool] = False,
929
+ return_dict: Optional[bool] = True,
930
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
931
+
932
+ mask_value = torch.finfo(attention_mask.dtype).min
933
+ n, _, __, t = attention_mask.size()
934
 
935
+ if not (self.config.is_decoder and encoder_hidden_states is not None):
 
 
 
 
 
 
 
 
 
936
 
937
+ b = self.block_size * 2
938
+ pad = t % self.block_size
939
+
940
+ # Check if t is multiple of block_size and pad
941
+ if self.adaptive and t > b and pad > 0:
942
+ pad_length = self.block_size - pad
943
+ hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
944
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
945
 
946
+ if self.mask_first_token:
947
+ attention_mask[..., 0] = mask_value
948
 
949
+ encoder_outputs = super().forward(
950
+ hidden_states=hidden_states,
951
+ attention_mask=attention_mask,
952
+ head_mask=head_mask,
953
+ encoder_hidden_states=encoder_hidden_states,
954
+ encoder_attention_mask=encoder_attention_mask,
955
+ past_key_values=past_key_values,
956
+ use_cache=use_cache,
957
+ output_attentions=output_attentions,
958
+ output_hidden_states=output_hidden_states,
959
+ return_dict=return_dict
960
+ )
961
 
962
+ sequence_output = encoder_outputs[0]
963
+ if self.pool_with_global:
964
+ sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
 
 
 
 
 
 
 
965
 
966
+ # Adapt sequence to initial shape
967
+ sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
968
 
969
+ if not return_dict:
970
+ return (sequence_output, ) + encoder_outputs[1:]
971
 
972
+ encoder_outputs.last_hidden_state = sequence_output
973
+ return encoder_outputs
 
 
 
974
 
975
  class LSGBertPreTrainedModel(BertPreTrainedModel):
976
  """
 
998
  LSGBertPreTrainedModel.__init__(self, config)
999
 
1000
  self.config = config
 
 
 
 
 
 
 
 
 
1001
 
1002
  self.embeddings = LSGBertEmbeddings(config)
1003
  self.encoder = LSGBertEncoder(config)
1004
+ self.pooler = BertPooler(config) if add_pooling_layer else None
1005
 
1006
  if config.add_cross_attention:
1007
  logger.warning(
 
1011
  # Initialize weights and apply final processing
1012
  self.post_init()
1013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1014
  def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1015
 
1016
  # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
 
1029
  return extended_attention_mask
1030
 
1031
 
1032
+ class LSGBertForPreTraining(LSGBertPreTrainedModel, BertForPreTraining):
1033
 
1034
  def __init__(self, config):
1035
 
1036
+ LSGBertPreTrainedModel.__init__(self, config)
1037
 
1038
  self.bert = LSGBertModel(config)
1039
+ self.cls = BertPreTrainingHeads(config)
1040
 
1041
  # Initialize weights and apply final processing
1042
  self.post_init()
1043
 
1044
 
1045
+ class LSGBertLMHeadModel(LSGBertPreTrainedModel, BertLMHeadModel):
1046
 
1047
  _keys_to_ignore_on_load_unexpected = [r"pooler"]
1048
  _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1049
 
1050
  def __init__(self, config):
1051
 
1052
+ LSGBertPreTrainedModel.__init__(self, config)
1053
 
1054
  if not config.is_decoder:
1055
  logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1056
 
1057
  self.bert = LSGBertModel(config, add_pooling_layer=False)
1058
+ self.cls = BertOnlyMLMHead(config)
1059
 
1060
  # Initialize weights and apply final processing
1061
  self.post_init()
 
1082
  )
1083
 
1084
  self.bert = LSGBertModel(config, add_pooling_layer=False)
1085
+ self.cls = BertOnlyMLMHead(config)
1086
 
1087
  # Initialize weights and apply final processing
1088
  self.post_init()
 
1095
  LSGBertPreTrainedModel.__init__(self, config)
1096
 
1097
  self.bert = LSGBertModel(config)
1098
+ self.cls = BertOnlyNSPHead(config)
1099
 
1100
  # Initialize weights and apply final processing
1101
  self.post_init()