ccdv committed
Commit df9960d · 1 Parent(s): 988b4fd

small fix with torch.finfo

Files changed (1)
  1. modeling_lsg_bert.py +71 -101
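The diff below keeps the convention of building additive attention masks from torch.finfo(dtype).min: positions to drop get the most negative finite value of the score dtype, so they contribute essentially nothing after the softmax while the value stays representable in reduced precision. A minimal standalone sketch of that pattern (tensor shapes and names are illustrative only, not taken from the module):

import torch

scores = torch.randn(2, 4, 8, 8, dtype=torch.float16)   # (batch, heads, queries, keys)
keep = (torch.rand(2, 1, 1, 8) > 0.5).to(scores.dtype)  # 1 = attend, 0 = mask out

# Masked positions receive the most negative finite value of the score dtype,
# so their softmax weight is ~0 without risking fp16/bf16 overflow.
additive_mask = (1. - keep.clamp(0, 1)) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive_mask, dim=-1)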
modeling_lsg_bert.py CHANGED
@@ -199,7 +199,7 @@ class CausalAttentionProduct(nn.Module):
                 diagonal=-1
             )
             causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
+            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]

             del attention_mask

@@ -540,7 +540,8 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -605,7 +606,8 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min

         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

@@ -903,6 +905,72 @@ class LSGBertEncoder(BertEncoder):

         self.layer = nn.ModuleList([LSGBertLayer(config) for _ in range(config.num_hidden_layers)])

+        assert hasattr(config, "num_global_tokens")
+        self.num_global_tokens = config.num_global_tokens
+        self.pad_idx = config.pad_token_id
+
+        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
+        self.block_size = config.block_size
+        self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
+        self.pool_with_global = config.pool_with_global
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+
+        mask_value = torch.finfo(attention_mask.dtype).min
+        n, _, __, t = attention_mask.size()
+
+        if not (self.config.is_decoder and encoder_hidden_states is not None):
+
+            b = self.block_size * 2
+            pad = t % self.block_size
+
+            # Check if t is multiple of block_size and pad
+            if self.adaptive and t > b and pad > 0:
+                pad_length = self.block_size - pad
+                hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
+
+            if self.mask_first_token:
+                attention_mask[..., 0] = mask_value
+
+        encoder_outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        sequence_output = encoder_outputs[0]
+        if self.pool_with_global:
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
+
+        # Adapt sequence to initial shape
+        sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
+
+        if not return_dict:
+            return (sequence_output, ) + encoder_outputs[1:]
+
+        encoder_outputs.last_hidden_state = sequence_output
+        return encoder_outputs

 class LSGBertPreTrainedModel(BertPreTrainedModel):
     """
@@ -930,15 +998,6 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         LSGBertPreTrainedModel.__init__(self, config)

         self.config = config
-        assert hasattr(config, "num_global_tokens")
-        self.num_global_tokens = config.num_global_tokens
-        self.pad_idx = config.pad_token_id
-
-        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
-        self.block_size = config.block_size
-        self.adaptive = config.adaptive
-        self.mask_first_token = config.mask_first_token
-        self.pool_with_global = config.pool_with_global

         self.embeddings = LSGBertEmbeddings(config)
         self.encoder = LSGBertEncoder(config)
@@ -952,95 +1011,6 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None
-    ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        inputs_ = input_ids if input_ids is not None else inputs_embeds
-        n, t = inputs_.size()[:2]
-
-        if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
-        if self.mask_first_token:
-            attention_mask[:,0] = 0
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(n, t, device=inputs_.device).long()
-
-        b = self.block_size * 2
-        pad = t % self.block_size
-
-        # Check if t is multiple of block_size and pad
-        if self.adaptive and t > b and pad > 0:
-            pad_length = self.block_size - pad
-            if input_ids is not None:
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
-            else:
-                inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
-
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
-
-            if position_ids is not None:
-                position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
-
-        n, t_ = attention_mask.size()
-
-        encoder_outputs = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        sequence_output = encoder_outputs[0]
-        if self.pool_with_global:
-            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
-
-        diff = t - t_
-        n, _, d = sequence_output.size()
-        sequence_output = sequence_output[..., self.num_global_tokens:, :]
-
-        # Adapt sequence to initial shape
-        if diff < 0:
-            sequence_output = sequence_output[:, :t]
-
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        encoder_outputs.last_hidden_state = sequence_output
-        encoder_outputs.pooler_output = pooled_output
-        return encoder_outputs
-
     def get_extended_attention_mask(self, attention_mask, input_shape, device=None):

         # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
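For context, the forward pass moved into LSGBertEncoder pads the sequence to a multiple of config.block_size (when adaptive is set and the sequence is long enough), extends the additive mask with torch.finfo(attention_mask.dtype).min so the padded positions are ignored, and trims the output back to the original length after the layers run. A rough standalone sketch of that padding step, with made-up shapes and block size:

import torch

block_size = 128
hidden_states = torch.randn(2, 300, 768)     # (batch, seq, hidden)
attention_mask = torch.zeros(2, 1, 1, 300)   # extended additive mask, 0 = attend

mask_value = torch.finfo(attention_mask.dtype).min
t = hidden_states.size(1)
pad = t % block_size
if pad > 0:
    pad_length = block_size - pad
    # Pad the sequence with zeros and the mask with the "masked" value.
    hidden_states = torch.nn.functional.pad(
        hidden_states.transpose(-1, -2), (0, pad_length), value=0.
    ).transpose(-1, -2)
    attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)

# ... run the transformer layers on the padded inputs, then trim back:
sequence_output = hidden_states[..., :t, :]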