ccdv committed on
Commit
79cac77
·
1 Parent(s): 3efaa9f
Files changed (2) hide show
  1. README.md +1 -1
  2. modeling_lsg_distilbert.py +20 -7
README.md CHANGED
@@ -6,7 +6,7 @@ tags:
6
  ---
7
 
8
  # LSG model
9
- **Transformers >= 4.35.2**\
10
  **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
11
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
12
 
 
6
  ---
7
 
8
  # LSG model
9
+ **Transformers >= 4.36.1**\
10
  **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
11
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
12
 
modeling_lsg_distilbert.py CHANGED
@@ -100,14 +100,22 @@ class LSGEmbeddings(Embeddings):
100
 
101
  self.block_size = config.block_size
102
 
103
- def forward(self, input_ids, inputs_embeds=None):
104
  """
105
  Parameters:
106
- input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.
 
 
 
 
 
107
  Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
108
  embeddings)
109
  """
110
- bs, seq_length = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
 
 
 
111
 
112
  # Setting the position-ids to the registered buffer in constructor, it helps
113
  # when tracing the model without passing position-ids, solves
@@ -116,9 +124,8 @@ class LSGEmbeddings(Embeddings):
116
  position_ids = self.position_ids[:, :seq_length]
117
  else:
118
  position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
119
- position_ids = position_ids.unsqueeze(0).expand(bs, seq_length) # (bs, max_seq_length)
120
-
121
- word_embeddings = self.word_embeddings(input_ids) if input_ids is not None else inputs_embeds
122
  position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
123
  word_embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim)
124
 
@@ -853,6 +860,12 @@ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
853
  self.transformer = LSGTransformer(config) # Encoder
854
  self.num_global_tokens = config.num_global_tokens
855
  # Initialize weights and apply final processing
 
 
 
 
 
 
856
  self.post_init()
857
 
858
 
@@ -952,4 +965,4 @@ try:
952
  str_to_class(value.split(".")[-1]).register_for_auto_class(key)
953
  except:
954
  warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
955
- warn("Update to transformers >= 4.35.2 to fix.")
 
100
 
101
  self.block_size = config.block_size
102
 
103
+ def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
104
  """
105
  Parameters:
106
+ input_ids (torch.Tensor):
107
+ torch.tensor(bs, max_seq_length) The token ids to embed.
108
+ input_embeds (*optional*, torch.Tensor):
109
+ The pre-computed word embeddings. Can only be passed if the input ids are `None`.
110
+
111
+
112
  Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
113
  embeddings)
114
  """
115
+ if input_ids is not None:
116
+ word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
117
+
118
+ seq_length = word_embeddings.size(1)
119
 
120
  # Setting the position-ids to the registered buffer in constructor, it helps
121
  # when tracing the model without passing position-ids, solves
 
124
  position_ids = self.position_ids[:, :seq_length]
125
  else:
126
  position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
127
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
128
+
 
129
  position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
130
  word_embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim)
131
 
 
860
  self.transformer = LSGTransformer(config) # Encoder
861
  self.num_global_tokens = config.num_global_tokens
862
  # Initialize weights and apply final processing
863
+
864
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
865
+ if self._use_flash_attention_2:
866
+ logger.warning(
867
+ "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
868
+ )
869
  self.post_init()
870
 
871
 
 
965
  str_to_class(value.split(".")[-1]).register_for_auto_class(key)
966
  except:
967
  warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
968
+ warn("Update to transformers >= 4.36.1 to fix.")