import os
import json
import random
from pathlib import Path
from typing import Dict, List

from transformers import PreTrainedTokenizer, AutoTokenizer

# pbt_log = []


class PickBestTokenizer(PreTrainedTokenizer):
    """Wraps several tokenizers and, for each text, keeps the shortest tokenization.

    Tokens are namespaced with a "[i]" prefix (i = index of the source tokenizer)
    and ids are shifted by a per-tokenizer offset so that all vocabularies share a
    single, non-overlapping id space.
    """

    def __init__(self, tokenizers: List[PreTrainedTokenizer], **kwargs):
        self.model_input_names = ["input_ids", "attention_mask"]
        # Accept either tokenizer instances or names/paths to load with AutoTokenizer.
        self.tokenizers = [
            AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
            for tokenizer in tokenizers
        ]
        self.tokenizers_offsets = []
        self.vocab = {}
        self._vocab_size = sum(len(tokenizer) for tokenizer in self.tokenizers)
        self.pad_token = '[0]' + (self.tokenizers[0].pad_token if self.tokenizers[0].pad_token else self.tokenizers[0].eos_token)
        # Build the merged vocabulary: each sub-vocabulary is prefixed with "[i]"
        # and shifted by the cumulative size of the preceding tokenizers.
        offset = 0
        for i, tokenizer in enumerate(self.tokenizers):
            tokenizer_id = f"[{i}]"
            self.tokenizers_offsets.append(offset)
            for token, token_id in tokenizer.get_vocab().items():
                self.vocab[tokenizer_id + token] = token_id + offset
            offset += len(tokenizer)
        super().__init__(**kwargs)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def tokenize(self, text: str, **kwargs) -> List[str]:
        # Tokenize the text with all available tokenizers
        tokenized_texts = [
            [f"[{i}]" + tok for tok in tokenizer.tokenize(text, **kwargs)]
            for i, tokenizer in enumerate(self.tokenizers)
        ]
        # Shuffle so that, in case of equal lengths, no tokenizer is favored
        random.shuffle(tokenized_texts)
        # Keep the shortest tokenization
        best_tokenization = min(tokenized_texts, key=len)
        # Log the output
        # pbt_log.append((text, best_tokenization))
        return best_tokenization

    def convert_tokens_to_ids(self, tokens: List[str], **kwargs) -> List[int]:
        if isinstance(tokens, str):
            return self.convert_tokens_to_ids([tokens])[0]
        ids = []
        for token in tokens:
            # Tokens carry a "[i]" prefix; this assumes at most 10 tokenizers
            # (a single-digit index), so the prefix is exactly 3 characters long.
            tokenizer_id = int(token[1])
            token_stripped = token[3:]
            offset = self.tokenizers_offsets[tokenizer_id]
            ids.append(self.tokenizers[tokenizer_id].convert_tokens_to_ids(token_stripped, **kwargs) + offset)
        return ids

    def convert_ids_to_tokens(self, ids: List[int], **kwargs) -> List[str]:
        if isinstance(ids, int):
            return self.convert_ids_to_tokens([ids])[0]
        tokens = []
        for id in ids:
            # Find the tokenizer whose id range contains this id, then undo the offset.
            for i, offset in enumerate(self.tokenizers_offsets):
                if id < offset + len(self.tokenizers[i]):
                    token_id = id - offset
                    tokens.append(f"[{i}]{self.tokenizers[i].convert_ids_to_tokens(token_id, **kwargs)}")
                    break
            else:
                raise ValueError(f"ID {id} is out of range for any tokenizer.")
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        raise NotImplementedError("This method should not be used in this class.")

    def _convert_id_to_token(self, index: int) -> str:
        raise NotImplementedError("This method should not be used in this class.")

    def save_pretrained(self, path, *args, **kwargs):
        # Ensure the save path exists
        os.makedirs(path, exist_ok=True)
        # Copy this file into the repository as `pick_best_tokenizer.py` so the
        # class can be resolved again through the auto_map below.
        source = Path(__file__)
        destination = Path(path) / 'pick_best_tokenizer.py'
        destination.write_bytes(source.read_bytes())
        # Save the config
        config = {
            "tokenizer_class": "PickBestTokenizer",
            "auto_map": ["pick_best_tokenizer.PickBestTokenizer", None],
            "tokenizers": [tokenizer.name_or_path for tokenizer in self.tokenizers],
        }
        with open(os.path.join(path, 'tokenizer_config.json'), 'w') as f:
            json.dump(config, f)


# Example usage
# tokenizer_fr = AutoTokenizer.from_pretrained("tokenizers/fineweb2_fr")
# tokenizer_nl = AutoTokenizer.from_pretrained("tokenizers/fineweb2_nl")
# tokenizer_de = AutoTokenizer.from_pretrained("tokenizers/fineweb2_de")
# tokenizer_en = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
# pick_best_tokenizer = PickBestTokenizer([tokenizer_fr, tokenizer_nl, tokenizer_de, tokenizer_en])

PickBestTokenizer.register_for_auto_class("AutoTokenizer")
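# A minimal save/reload sketch (illustrative, not part of the original file): the
# directory name "pick_best_tok" and the sample sentence are assumptions. It relies
# on the sub-tokenizers being loadable by the `name_or_path` recorded in the config,
# and on `trust_remote_code=True` to resolve the auto_map written by save_pretrained.
#
#   pick_best_tokenizer.save_pretrained("pick_best_tok")
#   reloaded = AutoTokenizer.from_pretrained("pick_best_tok", trust_remote_code=True)
#   tokens = reloaded.tokenize("Een korte voorbeeldzin.")
#   ids = reloaded.convert_tokens_to_ids(tokens)
#   assert reloaded.convert_ids_to_tokens(ids) == tokens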