Minor changes for correct inference

#1
by tomaarsen (HF staff) — opened
Files changed (4)
  1. README.md +1 -0
  2. config.json +3 -1
  3. model.py +2 -1
  4. tokenizer_config.json +4 -0
README.md CHANGED
@@ -4,6 +4,7 @@ datasets:
4
  - tiiuae/falcon-refinedweb
5
  language:
6
  - en
 
7
  ---
8
 
9
  # NeoBERT
 
4
  - tiiuae/falcon-refinedweb
5
  language:
6
  - en
7
+ library_name: transformers
8
  ---
9
 
10
  # NeoBERT
config.json CHANGED
@@ -4,7 +4,9 @@
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "model.NeoBERTConfig",
7
- "AutoModel": "model.NeoBERTLMHead"
 
 
8
  },
9
  "classifier_init_range": 0.02,
10
  "decoder_init_range": 0.02,
 
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "model.NeoBERTConfig",
7
+ "AutoModel": "model.NeoBERT",
8
+ "AutoModelForMaskedLM": "model.NeoBERTLMHead",
9
+ "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
10
  },
11
  "classifier_init_range": 0.02,
12
  "decoder_init_range": 0.02,
model.py CHANGED
@@ -190,7 +190,7 @@ class EncoderBlock(nn.Module):
190
  query=xq.transpose(1, 2),
191
  key=xk.transpose(1, 2),
192
  value=xv.transpose(1, 2),
193
- attn_mask=attention_mask,
194
  dropout_p=0,
195
  ).transpose(1, 2)
196
 
@@ -199,6 +199,7 @@ class EncoderBlock(nn.Module):
199
 
200
  class NeoBERTPreTrainedModel(PreTrainedModel):
201
  config_class = NeoBERTConfig
 
202
  _supports_cache_class = True
203
 
204
  def _init_weights(self, module):
 
190
  query=xq.transpose(1, 2),
191
  key=xk.transpose(1, 2),
192
  value=xv.transpose(1, 2),
193
+ attn_mask=attention_mask.bool(),
194
  dropout_p=0,
195
  ).transpose(1, 2)
196
 
 
199
 
200
  class NeoBERTPreTrainedModel(PreTrainedModel):
201
  config_class = NeoBERTConfig
202
+ base_model_prefix = "model"
203
  _supports_cache_class = True
204
 
205
  def _init_weights(self, module):
tokenizer_config.json CHANGED
@@ -46,6 +46,10 @@
46
  "do_lower_case": true,
47
  "extra_special_tokens": {},
48
  "mask_token": "[MASK]",
 
 
 
 
49
  "model_max_length": 4096,
50
  "pad_token": "[PAD]",
51
  "sep_token": "[SEP]",
 
46
  "do_lower_case": true,
47
  "extra_special_tokens": {},
48
  "mask_token": "[MASK]",
49
+ "model_input_names": [
50
+ "input_ids",
51
+ "attention_mask"
52
+ ],
53
  "model_max_length": 4096,
54
  "pad_token": "[PAD]",
55
  "sep_token": "[SEP]",