small fix
- README.md +1 -1
- modeling_lsg_distilbert.py +20 -7
README.md CHANGED

@@ -6,7 +6,7 @@ tags:
 ---
 
 # LSG model
-**Transformers >= 4.
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
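For readers landing on this commit, a minimal loading sketch (not part of the diff): the repository id below is a hypothetical placeholder, and the only point illustrated is the trust_remote_code=True flag that the README line above requires.

from transformers import AutoModel, AutoTokenizer

# Hypothetical placeholder id; substitute the actual LSG DistilBERT checkpoint.
model_id = "org-name/lsg-distilbert-checkpoint"

# trust_remote_code=True is needed because the model ships its own modeling_lsg_distilbert.py.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)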
modeling_lsg_distilbert.py CHANGED

@@ -100,14 +100,22 @@ class LSGEmbeddings(Embeddings):
 
         self.block_size = config.block_size
 
-    def forward(self, input_ids,
+    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Parameters:
-            input_ids
+            input_ids (torch.Tensor):
+                torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_embeds (*optional*, torch.Tensor):
+                The pre-computed word embeddings. Can only be passed if the input ids are `None`.
+
+
         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
         embeddings)
         """
-
+        if input_ids is not None:
+            word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+
+        seq_length = word_embeddings.size(1)
 
         # Setting the position-ids to the registered buffer in constructor, it helps
         # when tracing the model without passing position-ids, solves
@@ -116,9 +124,8 @@ class LSGEmbeddings(Embeddings):
             position_ids = self.position_ids[:, :seq_length]
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-            position_ids = position_ids.unsqueeze(0).
-
-        word_embeddings = self.word_embeddings(input_ids) if input_ids is not None else inputs_embeds
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
         word_embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
 
@@ -853,6 +860,12 @@ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
         self.transformer = LSGTransformer(config)  # Encoder
         self.num_global_tokens = config.num_global_tokens
         # Initialize weights and apply final processing
+
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
+            )
         self.post_init()
 
 
@@ -952,4 +965,4 @@ try:
         str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.
+    warn("Update to transformers >= 4.36.1 to fix.")