guymorganb committed
Commit 24ea396 · 1 Parent(s): 6db9dbd

updated safetensors so they match

Files changed (1)
  1. modeling_lsg_bert.py +80 -4
modeling_lsg_bert.py CHANGED
@@ -23,7 +23,11 @@ from transformers.models.bert.modeling_bert import (
 )
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions
+)
+
 from transformers.models.bert.configuration_bert import BertConfig
 import math
 import sys
@@ -817,7 +821,7 @@ class LSGSelfAttention(BaseSelfAttention):
         n, h, t, d = query_layer.size()
 
         # Cat global mask
-        attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
+        # attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
 
         # Use normal attention if local attention covers every tokens
         if t <= 2 * self.block_size + self.num_global_tokens:
@@ -1023,9 +1027,9 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
     def __init__(self, config, add_pooling_layer=True):
 
-        # ensure your LSGBertModel inherits all the necessary fields introduced in the latest Transformers.
+        # 1) Initialize the standard BertModel
         BertModel.__init__(self, config)
-
+        # 2) Initialize our LSG PreTrained
        LSGBertPreTrainedModel.__init__(self, config)
 
         self.config = config
@@ -1041,6 +1045,78 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        # ----------------------------
+        # 1) Use LSG embeddings
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values[0][0].size(2) if past_key_values else 0
+        ) if (input_ids is not None or inputs_embeds is not None) else None
+
+        # 2) If we have an attention mask and some global tokens, pad the mask
+        #    by `config.num_global_tokens` so it matches embedding_output.size(1).
+        if attention_mask is not None and self.config.num_global_tokens > 0:
+            # Original shape: (batch_size, seq_len)
+            bsz, seq_len = attention_mask.shape
+
+            new_shape = (bsz, seq_len + self.config.num_global_tokens)
+            extended_mask = torch.zeros(new_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+            # Fill from index `num_global_tokens` onward
+            extended_mask[:, self.config.num_global_tokens:] = attention_mask
+            attention_mask = extended_mask
+
+        # 3) Now call self.encoder with the updated mask
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=attention_mask.unsqueeze(1).unsqueeze(2) if attention_mask is not None else None,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        # 4) Grab the last hidden state
+        sequence_output = encoder_outputs[0]
+
+        # 5) Optionally apply the pooler
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        # Return
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
 
     def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
 
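The key addition is the mask-padding step in the new `LSGBertModel.forward`: the 2D attention mask is grown by `config.num_global_tokens` positions on the left so its length matches the embedding output, which has the global tokens prepended. Below is a minimal sketch of that shape alignment, runnable on its own; the value of `num_global_tokens` and the toy (2, 8) mask are illustrative assumptions, not taken from the commit.

import torch

num_global_tokens = 1               # illustrative; the real value comes from config.num_global_tokens
attention_mask = torch.ones(2, 8)   # toy (batch_size, seq_len) mask

bsz, seq_len = attention_mask.shape

# Same steps as the committed forward(): allocate a mask that is num_global_tokens
# longer and copy the original mask in after the global slots.
extended_mask = torch.zeros(bsz, seq_len + num_global_tokens, dtype=attention_mask.dtype)
extended_mask[:, num_global_tokens:] = attention_mask

print(extended_mask.shape)          # torch.Size([2, 9]), matching the global-token-prefixed embeddings

The global positions are left at 0, exactly as in the committed code; this model-level padding appears to take over from the per-layer `torch.nn.functional.pad` call that the same commit comments out in `LSGSelfAttention`.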