File size: 4,781 Bytes
76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 3ba50ba 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 6ee9897 76b4794 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from __future__ import annotations
from .mecab_tokenizer import MeCabTokenizer
import os
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
def save_stoi(stoi: dict[str, int], vocab_file: str):
"""単語IDの辞書を配列にしてvocab_fileに保存します。
Args:
stoi (dict[str, int]): 単語IDのマッピング
vocab_file (str): 保存するパス
Raises:
ValueError: IDが途切れているとエラーを起こします。
"""
with open(vocab_file, "w", encoding="utf-8") as writer:
index = 0
for token, token_index in sorted(stoi.items(), key=lambda kv: kv[1]):
if index != token_index:
raise ValueError(
"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!")
writer.write(token + "\n")
index += 1
def load_stoi(vocab_file: str) -> dict[str, int]:
"""ファイルから単語IDの辞書をロードします。
Args:
vocab_file (str): ファイルのパス
Returns:
dict[str, int]: 単語IDのマッピング
"""
stoi: dict[str, int] = {}
# ファイルから読み出し
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
# 単語IDのマッピングを生成します。
for index, token in enumerate(tokens):
token = token.rstrip("\n")
stoi[token] = index
return stoi
class FastTextJpTokenizer(MeCabTokenizer):
# Configが認識するのに必要です。
# https://huggingface.co/docs/transformers/custom_models#writing-a-custom-configuration
model_type = "fasttext_jp"
# vocab.txtを認識するのにおそらく必要。
vocab_files_names = VOCAB_FILES_NAMES
def __init__(self,
vocab_file: str,
hinshi: list[str] | None = None,
mecab_dicdir: str | None = None,
**kwargs):
"""初期化処理
Args:
vocab_file (str): vocab_fileのpath
hinshi (list[str] | None, optional): 抽出する品詞
mecab_dicdir (str | None, optional): dicrcのあるディレクトリ
"""
super().__init__(hinshi, mecab_dicdir, **kwargs)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.stoi = load_stoi(vocab_file)
self.itos = dict([(ids, tok) for tok, ids in self.stoi.items()])
@property
def vocab_size(self) -> int:
"""ボキャブラリのサイズ
※PreTrainedTokenizerで実装すべき必須の関数。
Returns:
int: ボキャブラリのサイズ
"""
return len(self.stoi)
def _convert_token_to_id(self, token: str) -> int:
"""単語からID
※PreTrainedTokenizerで実装すべき必須の関数。
Args:
token (str): 単語
Returns:
int: ID
"""
return self.stoi[token]
def _convert_id_to_token(self, index: int) -> str:
"""IDから単語
※PreTrainedTokenizerで実装すべき必須の関数。
Args:
index (int): ID
Returns:
str: 単語
"""
return self.itos[index]
def save_vocabulary(self,
save_directory: str,
filename_prefix: str | None = None) -> tuple[str]:
"""ボキャブラリの保存
Args:
save_directory (str): 保存するディレクトリ。ファイル名はvocab.txtに固定
filename_prefix (str | None, optional): ファイルのprefix
Returns:
tuple[str]: ファイル名を返す。
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"])
else:
vocab_file = (filename_prefix +
"-" if filename_prefix else "") + save_directory
save_stoi(self.stoi, vocab_file)
return (vocab_file, )
# AutoTokenizerに登録が必要だが、いろいろやり方が変わっているようで定まっていない。(2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextJpTokenizer.register_for_auto_class("AutoTokenizer")
|