# Fairly-Multilingual-ModernBERT-Embed-BE / pick_best_tokenizer.py
import os
import json
import random
from typing import List, Dict
from transformers import PreTrainedTokenizer, AutoTokenizer
#pbt_log = []
class PickBestTokenizer(PreTrainedTokenizer):
    """Wraps several tokenizers and, for every text, keeps the tokenization that uses the fewest tokens.

    Each sub-tokenizer i gets its own slice of the merged vocabulary: its tokens are prefixed
    with "[i]" and its ids are shifted by a per-tokenizer offset.
    """
    def __init__(self, tokenizers: List[PreTrainedTokenizer], **kwargs):
        self.model_input_names = ["input_ids", "attention_mask"]
        # Accept tokenizer objects as well as model names/paths to load them from
        self.tokenizers = [AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer for tokenizer in tokenizers]
        self.tokenizers_offsets = []
        self.vocab = {}
        self._vocab_size = sum(len(tokenizer) for tokenizer in self.tokenizers)
        # Use the first tokenizer's pad token (falling back to its EOS token), prefixed with its index
        self.pad_token = '[0]' + (self.tokenizers[0].pad_token if self.tokenizers[0].pad_token else self.tokenizers[0].eos_token)
        # Build the merged vocabulary: tokens are prefixed with "[i]" and ids are shifted by the
        # cumulative size of the preceding tokenizers
        offset = 0
        for i, tokenizer in enumerate(self.tokenizers):
            tokenizer_id = f"[{i}]"
            self.tokenizers_offsets.append(offset)
            for token, token_id in tokenizer.get_vocab().items():
                self.vocab[tokenizer_id + token] = token_id + offset
            offset += len(tokenizer)
        super().__init__(**kwargs)
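
    # Illustration (hypothetical sizes): with two wrapped tokenizers of 50,000 and 32,000 tokens,
    # tokenizers_offsets becomes [0, 50000] and vocab_size is 82,000; ids coming from the second
    # tokenizer are shifted up by 50,000 in the merged vocabulary.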

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def tokenize(self, text: str, **kwargs) -> List[str]:
        # Tokenize the text with all possible tokenizers
        tokenized_texts = [
            [f"[{i}]" + tok for tok in tokenizer.tokenize(text, **kwargs)]
            for i, tokenizer in enumerate(self.tokenizers)
        ]
        # Ensure that in case of equal lengths, no tokenizer is favored
        random.shuffle(tokenized_texts)
        # Return the list of tokens which is shortest
        best_tokenization = min(tokenized_texts, key=len)
        # Log the output
        #pbt_log.append((text, best_tokenization))
        # Return the output
        return best_tokenization
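
    # Illustration (hypothetical output): if the tokenizer at index 1 produces the fewest pieces for
    # a given text, every returned token is prefixed with "[1]", e.g. ["[1]voorbeeld", "[1]zin"];
    # the actual pieces depend on the wrapped tokenizers.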

    def convert_tokens_to_ids(self, tokens: List[str], **kwargs) -> List[int]:
        if isinstance(tokens, str): return self.convert_tokens_to_ids([tokens])[0]
        ids = []
        for token in tokens:
            # Tokens look like "[i]token": read the sub-tokenizer index (this assumes fewer than 10 tokenizers)
            tokenizer_id = int(token[1])
            token_stripped = token[3:]
            offset = self.tokenizers_offsets[tokenizer_id]
            ids.append(self.tokenizers[tokenizer_id].convert_tokens_to_ids(token_stripped, **kwargs) + offset)
        return ids

    def convert_ids_to_tokens(self, ids: List[int], **kwargs) -> List[str]:
        if isinstance(ids, int): return self.convert_ids_to_tokens([ids])[0]
        tokens = []
        for id in ids:
            # Find the sub-tokenizer whose id range contains this id (offsets are cumulative and sorted)
            for i, offset in enumerate(self.tokenizers_offsets):
                if id < offset + len(self.tokenizers[i]):
                    token_id = id - offset
                    tokens.append(f"[{i}]{self.tokenizers[i].convert_ids_to_tokens(token_id, **kwargs)}")
                    break
            else:
                raise ValueError(f"ID {id} is out of range for any tokenizer.")
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        raise NotImplementedError("This method should not be used in this class.")

    def _convert_id_to_token(self, index: int) -> str:
        raise NotImplementedError("This method should not be used in this class.")

    def save_pretrained(self, path, *args, **kwargs):
        # Ensure the save path exists
        os.makedirs(path, exist_ok=True)
        # Save this file in the repository as `pick_best_tokenizer.py` so it ships with the model
        from pathlib import Path
        source = Path(__file__)
        destination = Path(os.path.join(path, 'pick_best_tokenizer.py'))
        destination.write_bytes(source.read_bytes())
        # Save the config, listing the wrapped tokenizers so they can be reloaded by name or path
        config = {
            "tokenizer_class": "PickBestTokenizer",
            "auto_map": ["pick_best_tokenizer.PickBestTokenizer", None],
            "tokenizers": [tokenizer.name_or_path for tokenizer in self.tokenizers]
        }
        with open(os.path.join(path, 'tokenizer_config.json'), 'w') as f:
            json.dump(config, f)
# Example usage
#tokenizer_fr = AutoTokenizer.from_pretrained("tokenizers/fineweb2_fr")
#tokenizer_nl = AutoTokenizer.from_pretrained("tokenizers/fineweb2_nl")
#tokenizer_de = AutoTokenizer.from_pretrained("tokenizers/fineweb2_de")
#tokenizer_en = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
#pick_best_tokenizer = PickBestTokenizer([tokenizer_fr, tokenizer_nl, tokenizer_de, tokenizer_en])
PickBestTokenizer.register_for_auto_class("AutoTokenizer")
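
# A minimal reload sketch (an assumption, not part of the original file): after calling
# `save_pretrained("some_dir")`, the auto_map above should let AutoTokenizer rebuild the class;
# "some_dir" is a hypothetical path used only for illustration.
#reloaded = AutoTokenizer.from_pretrained("some_dir", trust_remote_code=True)
#print(reloaded.tokenize("Een korte voorbeeldzin."))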