Minor changes for correct inference
#1
by
tomaarsen
HF staff
- opened
- README.md +1 -0
- config.json +3 -1
- model.py +2 -1
- tokenizer_config.json +4 -0
README.md
CHANGED
@@ -4,6 +4,7 @@ datasets:
|
|
4 |
- tiiuae/falcon-refinedweb
|
5 |
language:
|
6 |
- en
|
|
|
7 |
---
|
8 |
|
9 |
# NeoBERT
|
|
|
4 |
- tiiuae/falcon-refinedweb
|
5 |
language:
|
6 |
- en
|
7 |
+
library_name: transformers
|
8 |
---
|
9 |
|
10 |
# NeoBERT
|
config.json
CHANGED
@@ -4,7 +4,9 @@
|
|
4 |
],
|
5 |
"auto_map": {
|
6 |
"AutoConfig": "model.NeoBERTConfig",
|
7 |
-
"AutoModel": "model.NeoBERT"
|
|
|
|
|
8 |
},
|
9 |
"classifier_init_range": 0.02,
|
10 |
"decoder_init_range": 0.02,
|
|
|
4 |
],
|
5 |
"auto_map": {
|
6 |
"AutoConfig": "model.NeoBERTConfig",
|
7 |
+
"AutoModel": "model.NeoBERT",
|
8 |
+
"AutoModelForMaskedLM": "model.NeoBERTLMHead",
|
9 |
+
"AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
|
10 |
},
|
11 |
"classifier_init_range": 0.02,
|
12 |
"decoder_init_range": 0.02,
|
model.py
CHANGED
@@ -190,7 +190,7 @@ class EncoderBlock(nn.Module):
|
|
190 |
query=xq.transpose(1, 2),
|
191 |
key=xk.transpose(1, 2),
|
192 |
value=xv.transpose(1, 2),
|
193 |
-
attn_mask=attention_mask,
|
194 |
dropout_p=0,
|
195 |
).transpose(1, 2)
|
196 |
|
@@ -199,6 +199,7 @@ class EncoderBlock(nn.Module):
|
|
199 |
|
200 |
class NeoBERTPreTrainedModel(PreTrainedModel):
|
201 |
config_class = NeoBERTConfig
|
|
|
202 |
_supports_cache_class = True
|
203 |
|
204 |
def _init_weights(self, module):
|
|
|
190 |
query=xq.transpose(1, 2),
|
191 |
key=xk.transpose(1, 2),
|
192 |
value=xv.transpose(1, 2),
|
193 |
+
attn_mask=attention_mask.bool(),
|
194 |
dropout_p=0,
|
195 |
).transpose(1, 2)
|
196 |
|
|
|
199 |
|
200 |
class NeoBERTPreTrainedModel(PreTrainedModel):
|
201 |
config_class = NeoBERTConfig
|
202 |
+
base_model_prefix = "model"
|
203 |
_supports_cache_class = True
|
204 |
|
205 |
def _init_weights(self, module):
|
tokenizer_config.json
CHANGED
@@ -46,6 +46,10 @@
|
|
46 |
"do_lower_case": true,
|
47 |
"extra_special_tokens": {},
|
48 |
"mask_token": "[MASK]",
|
|
|
|
|
|
|
|
|
49 |
"model_max_length": 4096,
|
50 |
"pad_token": "[PAD]",
|
51 |
"sep_token": "[SEP]",
|
|
|
46 |
"do_lower_case": true,
|
47 |
"extra_special_tokens": {},
|
48 |
"mask_token": "[MASK]",
|
49 |
+
"model_input_names": [
|
50 |
+
"input_ids",
|
51 |
+
"attention_mask"
|
52 |
+
],
|
53 |
"model_max_length": 4096,
|
54 |
"pad_token": "[PAD]",
|
55 |
"sep_token": "[SEP]",
|