import os
import pathlib
import rwkv_world_tokenizer
from typing import List, Tuple, Callable


def add_tokenizer_argument(parser) -> None:
    # Optional positional argument; the default 'auto' makes get_tokenizer guess
    # the tokenizer from the model's vocabulary size (n_vocab).
    parser.add_argument(
        'tokenizer',
        help='Tokenizer to use; supported tokenizers: auto (guess from n_vocab), 20B, world',
        nargs='?',
        type=str,
        default='auto'
    )


def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[
|
    Callable[[List[int]], str],
    Callable[[str], List[int]]
]:
    # Returns a (decode, encode) pair: decode turns a list of token ids back into
    # a string, encode turns a string into a list of token ids.
    if tokenizer_name == 'auto':
        if n_vocab == 50277:
            tokenizer_name = '20B'
        elif n_vocab == 65536:
            tokenizer_name = 'world'
        else:
            raise ValueError(f'Can not guess the tokenizer from n_vocab value of {n_vocab}')

    # Directory containing this file; 20B_tokenizer.json is expected to sit next to it.
    parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

    if tokenizer_name == 'world':
        print('Loading World v20230424 tokenizer')
        return rwkv_world_tokenizer.get_world_tokenizer_v20230424()
    elif tokenizer_name == '20B':
        print('Loading 20B tokenizer')
        import tokenizers
        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json'))
        return tokenizer.decode, lambda x: tokenizer.encode(x).ids
    else:
        raise ValueError(f'Unknown tokenizer {tokenizer_name}')
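

# Illustrative usage sketch, not part of the original module. It assumes an
# argparse-driven script and hard-codes n_vocab=50277 (which makes 'auto' resolve
# to the 20B tokenizer); in a real script, n_vocab would come from the loaded model.
if __name__ == '__main__':
    import argparse

    example_parser = argparse.ArgumentParser(description='Tokenizer utility demo')
    add_tokenizer_argument(example_parser)
    example_args = example_parser.parse_args()

    # decode: List[int] -> str, encode: str -> List[int]
    decode, encode = get_tokenizer(example_args.tokenizer, n_vocab=50277)

    token_ids = encode('Hello, world!')
    print(token_ids)
    print(decode(token_ids))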