update test.py
Browse files
test.py
CHANGED
@@ -3,8 +3,8 @@ import os
|
|
3 |
import fire
|
4 |
import torch
|
5 |
from functools import partial
|
6 |
-
from transformers import
|
7 |
-
from transformers import
|
8 |
from pya0.preprocess import preprocess_for_transformer
|
9 |
|
10 |
|
@@ -25,14 +25,10 @@ def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
|
|
25 |
str(tokenizer.convert_ids_to_tokens(top_cands)))
|
26 |
|
27 |
|
28 |
-
def test(
|
29 |
-
test_file='test.txt',
|
30 |
-
ckpt_bert='ckpt/bert-pretrained-for-math-7ep/6_3_1382',
|
31 |
-
ckpt_tokenizer='ckpt/bert-tokenizer-for-math'
|
32 |
-
):
|
33 |
|
34 |
-
tokenizer =
|
35 |
-
model =
|
36 |
tie_word_embeddings=True
|
37 |
)
|
38 |
with open(test_file, 'r') as fh:
|
|
|
3 |
import fire
|
4 |
import torch
|
5 |
from functools import partial
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
from transformers import AutoModelForPreTraining
|
8 |
from pya0.preprocess import preprocess_for_transformer
|
9 |
|
10 |
|
|
|
25 |
str(tokenizer.convert_ids_to_tokens(top_cands)))
|
26 |
|
27 |
|
28 |
+
def test(model_name_or_path, tokenizer_name_or_path, test_file='test.txt'):
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
|
31 |
+
model = AutoModelForPreTraining.from_pretrained(model_name_or_path,
|
32 |
tie_word_embeddings=True
|
33 |
)
|
34 |
with open(test_file, 'r') as fh:
|