w32zhong commited on
Commit
677d4b0
1 Parent(s): 85e3474

update test.py

Browse files
Files changed (1) hide show
  1. test.py +5 -9
test.py CHANGED
@@ -3,8 +3,8 @@ import os
3
  import fire
4
  import torch
5
  from functools import partial
6
- from transformers import BertTokenizer
7
- from transformers import BertForPreTraining
8
  from pya0.preprocess import preprocess_for_transformer
9
 
10
 
@@ -25,14 +25,10 @@ def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
25
  str(tokenizer.convert_ids_to_tokens(top_cands)))
26
 
27
 
28
- def test(
29
- test_file='test.txt',
30
- ckpt_bert='ckpt/bert-pretrained-for-math-7ep/6_3_1382',
31
- ckpt_tokenizer='ckpt/bert-tokenizer-for-math'
32
- ):
33
 
34
- tokenizer = BertTokenizer.from_pretrained(ckpt_tokenizer)
35
- model = BertForPreTraining.from_pretrained(ckpt_bert,
36
  tie_word_embeddings=True
37
  )
38
  with open(test_file, 'r') as fh:
 
3
  import fire
4
  import torch
5
  from functools import partial
6
+ from transformers import AutoTokenizer
7
+ from transformers import AutoModelForPreTraining
8
  from pya0.preprocess import preprocess_for_transformer
9
 
10
 
 
25
  str(tokenizer.convert_ids_to_tokens(top_cands)))
26
 
27
 
28
+ def test(model_name_or_path, tokenizer_name_or_path, test_file='test.txt'):
 
 
 
 
29
 
30
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
31
+ model = AutoModelForPreTraining.from_pretrained(model_name_or_path,
32
  tie_word_embeddings=True
33
  )
34
  with open(test_file, 'r') as fh: