fasttext-classification / fasttext_jp_tokenizer.py
Taizo Kaneko
commit files to HF hub
97c46f0
raw
history blame
4.78 kB
from __future__ import annotations
from .mecab_tokenizer import MeCabTokenizer
import os
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
def save_stoi(stoi: dict[str, int], vocab_file: str):
"""単語IDの辞書を配列にしてvocab_fileに保存します。
Args:
stoi (dict[str, int]): 単語IDのマッピング
vocab_file (str): 保存するパス
Raises:
ValueError: IDが途切れているとエラーを起こします。
"""
with open(vocab_file, "w", encoding="utf-8") as writer:
index = 0
for token, token_index in sorted(stoi.items(), key=lambda kv: kv[1]):
if index != token_index:
raise ValueError(
"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!")
writer.write(token + "\n")
index += 1
def load_stoi(vocab_file: str) -> dict[str, int]:
"""ファイルから単語IDの辞書をロードします。
Args:
vocab_file (str): ファイルのパス
Returns:
dict[str, int]: 単語IDのマッピング
"""
stoi: dict[str, int] = {}
# ファイルから読み出し
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
# 単語IDのマッピングを生成します。
for index, token in enumerate(tokens):
token = token.rstrip("\n")
stoi[token] = index
return stoi
class FastTextJpTokenizer(MeCabTokenizer):
# Configが認識するのに必要です。
# https://huggingface.co/docs/transformers/custom_models#writing-a-custom-configuration
model_type = "fasttext_jp"
# vocab.txtを認識するのにおそらく必要。
vocab_files_names = VOCAB_FILES_NAMES
def __init__(self,
vocab_file: str,
hinshi: list[str] | None = None,
mecab_dicdir: str | None = None,
**kwargs):
"""初期化処理
Args:
vocab_file (str): vocab_fileのpath
hinshi (list[str] | None, optional): 抽出する品詞
mecab_dicdir (str | None, optional): dicrcのあるディレクトリ
"""
super().__init__(hinshi, mecab_dicdir, **kwargs)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.stoi = load_stoi(vocab_file)
self.itos = dict([(ids, tok) for tok, ids in self.stoi.items()])
@property
def vocab_size(self) -> int:
"""ボキャブラリのサイズ
※PreTrainedTokenizerで実装すべき必須の関数。
Returns:
int: ボキャブラリのサイズ
"""
return len(self.stoi)
def _convert_token_to_id(self, token: str) -> int:
"""単語からID
※PreTrainedTokenizerで実装すべき必須の関数。
Args:
token (str): 単語
Returns:
int: ID
"""
return self.stoi[token]
def _convert_id_to_token(self, index: int) -> str:
"""IDから単語
※PreTrainedTokenizerで実装すべき必須の関数。
Args:
index (int): ID
Returns:
str: 単語
"""
return self.itos[index]
def save_vocabulary(self,
save_directory: str,
filename_prefix: str | None = None) -> tuple[str]:
"""ボキャブラリの保存
Args:
save_directory (str): 保存するディレクトリ。ファイル名はvocab.txtに固定
filename_prefix (str | None, optional): ファイルのprefix
Returns:
tuple[str]: ファイル名を返す。
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"])
else:
vocab_file = (filename_prefix +
"-" if filename_prefix else "") + save_directory
save_stoi(self.stoi, vocab_file)
return (vocab_file, )
# AutoTokenizerに登録が必要だが、いろいろやり方が変わっているようで定まっていない。(2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextJpTokenizer.register_for_auto_class("AutoTokenizer")