ccdv commited on
Commit
137cdc2
·
1 Parent(s): 5657803
Files changed (2) hide show
  1. README.md +1 -1
  2. modeling_lsg_albert.py +47 -42
README.md CHANGED
@@ -8,7 +8,7 @@ pipeline_tag: fill-mask
8
  ---
9
 
10
  # LSG model
11
- **Transformers >= 4.35.2**\
12
  **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
13
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
14
 
 
8
  ---
9
 
10
  # LSG model
11
+ **Transformers >= 4.36.1**\
12
  **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
13
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
14
 
modeling_lsg_albert.py CHANGED
@@ -413,54 +413,54 @@ class LSGAlbertEmbeddings(AlbertEmbeddings):
413
  self.block_size = config.block_size
414
 
415
  def forward(
416
- self,
417
- input_ids=None,
418
- token_type_ids=None,
419
- position_ids=None,
420
- inputs_embeds=None,
421
- past_key_values_length=0,
422
- ) -> torch.Tensor:
423
- if input_ids is not None:
424
- input_shape = input_ids.size()
425
- else:
426
- input_shape = inputs_embeds.size()[:-1]
427
 
428
- seq_length = input_shape[1]
429
 
430
- if position_ids is None:
431
- position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
432
 
433
- # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
434
- # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
435
- # issue #5664
436
- if token_type_ids is None:
437
- if hasattr(self, "token_type_ids"):
438
- buffered_token_type_ids = self.token_type_ids[:, :seq_length]
439
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
440
- token_type_ids = buffered_token_type_ids_expanded
441
- else:
442
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
443
 
444
- if inputs_embeds is None:
445
- inputs_embeds = self.word_embeddings(input_ids)
446
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
447
 
448
- embeddings = inputs_embeds + token_type_embeddings
449
- if self.position_embedding_type == "absolute":
450
- position_embeddings = self.position_embeddings(position_ids)
451
- embeddings += position_embeddings
452
 
453
- n, t, d = embeddings.size()
454
-
455
- # Add global_tokens
456
- indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
457
- global_embeddings = self.global_embeddings(indexes)
458
- embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
459
-
460
 
461
- embeddings = self.LayerNorm(embeddings)
462
- embeddings = self.dropout(embeddings)
463
- return embeddings
464
 
465
 
466
  class LSGSelfAttention(BaseSelfAttention):
@@ -907,6 +907,11 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
907
  self.pooler = None
908
  self.pooler_activation = None
909
 
 
 
 
 
 
910
  # Initialize weights and apply final processing
911
  self.post_init()
912
 
@@ -1015,4 +1020,4 @@ try:
1015
  str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1016
  except:
1017
  warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1018
- warn("Update to transformers >= 4.35.2 to fix.")
 
413
  self.block_size = config.block_size
414
 
415
  def forward(
416
+ self,
417
+ input_ids: Optional[torch.LongTensor] = None,
418
+ token_type_ids: Optional[torch.LongTensor] = None,
419
+ position_ids: Optional[torch.LongTensor] = None,
420
+ inputs_embeds: Optional[torch.FloatTensor] = None,
421
+ past_key_values_length: int = 0,
422
+ ) -> torch.Tensor:
423
+ if input_ids is not None:
424
+ input_shape = input_ids.size()
425
+ else:
426
+ input_shape = inputs_embeds.size()[:-1]
427
 
428
+ seq_length = input_shape[1]
429
 
430
+ if position_ids is None:
431
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
432
 
433
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
434
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
435
+ # issue #5664
436
+ if token_type_ids is None:
437
+ if hasattr(self, "token_type_ids"):
438
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
439
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
440
+ token_type_ids = buffered_token_type_ids_expanded
441
+ else:
442
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
443
 
444
+ if inputs_embeds is None:
445
+ inputs_embeds = self.word_embeddings(input_ids)
446
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
447
 
448
+ embeddings = inputs_embeds + token_type_embeddings
449
+ if self.position_embedding_type == "absolute":
450
+ position_embeddings = self.position_embeddings(position_ids)
451
+ embeddings += position_embeddings
452
 
453
+ n, t, d = embeddings.size()
454
+
455
+ # Add global_tokens
456
+ indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
457
+ global_embeddings = self.global_embeddings(indexes)
458
+ embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
459
+
460
 
461
+ embeddings = self.LayerNorm(embeddings)
462
+ embeddings = self.dropout(embeddings)
463
+ return embeddings
464
 
465
 
466
  class LSGSelfAttention(BaseSelfAttention):
 
907
  self.pooler = None
908
  self.pooler_activation = None
909
 
910
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
911
+ if self._use_flash_attention_2:
912
+ logger.warning(
913
+ "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
914
+ )
915
  # Initialize weights and apply final processing
916
  self.post_init()
917
 
 
1020
  str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1021
  except:
1022
  warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1023
+ warn("Update to transformers >= 4.36.1 to fix.")