import os
import json
import random
from pathlib import Path
from typing import List, Dict
from transformers import PreTrainedTokenizer, AutoTokenizer

#pbt_log = []
class PickBestTokenizer(PreTrainedTokenizer):
    """Wraps several tokenizers and keeps, for every input, the tokenization
    that produces the fewest tokens. Each sub-tokenizer i gets its own slice
    of the combined vocabulary: its token strings are prefixed with "[i]" and
    its token ids are shifted by a per-tokenizer offset, so tokens from
    different tokenizers never collide."""

    def __init__(self, tokenizers: List[PreTrainedTokenizer], **kwargs):
        self.model_input_names = ["input_ids", "attention_mask"]
        # Accept either tokenizer objects or names/paths to load with AutoTokenizer.
        self.tokenizers = [
            AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
            for tokenizer in tokenizers
        ]
        self.tokenizers_offsets = []
        self.vocab = {}
        self._vocab_size = sum(len(tokenizer) for tokenizer in self.tokenizers)
        # Use the first tokenizer's pad token (falling back to its EOS token),
        # prefixed with "[0]" like every other token of that tokenizer.
        self.pad_token = '[0]' + (self.tokenizers[0].pad_token if self.tokenizers[0].pad_token else self.tokenizers[0].eos_token)

        # Build the combined vocabulary: each sub-tokenizer occupies a disjoint id range.
        offset = 0
        for i, tokenizer in enumerate(self.tokenizers):
            tokenizer_id = f"[{i}]"
            self.tokenizers_offsets.append(offset)
            for token, token_id in tokenizer.get_vocab().items():
                self.vocab[tokenizer_id + token] = token_id + offset
            offset += len(tokenizer)

        super().__init__(**kwargs)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def tokenize(self, text: str, **kwargs) -> List[str]:
        # Tokenize the text with every candidate tokenizer, prefixing each
        # token with the index of the tokenizer that produced it.
        tokenized_texts = [
            [f"[{i}]" + tok for tok in tokenizer.tokenize(text, **kwargs)]
            for i, tokenizer in enumerate(self.tokenizers)
        ]

        # Shuffle so that, on ties, no tokenizer is systematically favored.
        random.shuffle(tokenized_texts)

        # Keep the shortest tokenization.
        best_tokenization = min(tokenized_texts, key=len)

        # Optionally log the output for debugging.
        #pbt_log.append((text, best_tokenization))

        return best_tokenization

    def convert_tokens_to_ids(self, tokens: List[str], **kwargs) -> List[int]:
        if isinstance(tokens, str):
            return self.convert_tokens_to_ids([tokens])[0]
        ids = []
        for token in tokens:
            # Tokens are prefixed with "[i]"; parse up to the closing bracket so
            # that more than ten sub-tokenizers are handled correctly.
            prefix_end = token.index(']')
            tokenizer_id = int(token[1:prefix_end])
            token_stripped = token[prefix_end + 1:]
            offset = self.tokenizers_offsets[tokenizer_id]
            ids.append(self.tokenizers[tokenizer_id].convert_tokens_to_ids(token_stripped, **kwargs) + offset)
        return ids

    def convert_ids_to_tokens(self, ids: List[int], **kwargs) -> List[str]:
        if isinstance(ids, int):
            return self.convert_ids_to_tokens([ids])[0]
        tokens = []
        for index in ids:
            # Find the sub-tokenizer whose id range contains this id.
            for i, offset in enumerate(self.tokenizers_offsets):
                if index < offset + len(self.tokenizers[i]):
                    token_id = index - offset
                    tokens.append(f"[{i}]{self.tokenizers[i].convert_ids_to_tokens(token_id, **kwargs)}")
                    break
            else:
                raise ValueError(f"ID {index} is out of range for any tokenizer.")
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        raise NotImplementedError("This method should not be used in this class.")

    def _convert_id_to_token(self, index: int) -> str:
        raise NotImplementedError("This method should not be used in this class.")

    def save_pretrained(self, path, *args, **kwargs):
        # Ensure the save directory exists.
        os.makedirs(path, exist_ok=True)
        # Copy this file into the repository as `pick_best_tokenizer.py` so the
        # tokenizer class can be reloaded via the auto_map mechanism.
        source = Path(__file__)
        destination = Path(path) / 'pick_best_tokenizer.py'
        destination.write_bytes(source.read_bytes())
        # Save the config, recording the sub-tokenizers by name or path.
        config = {
            "tokenizer_class": "PickBestTokenizer",
            "auto_map": ["pick_best_tokenizer.PickBestTokenizer", None],
            "tokenizers": [tokenizer.name_or_path for tokenizer in self.tokenizers]
        }
        with open(os.path.join(path, 'tokenizer_config.json'), 'w') as f:
            json.dump(config, f)

# Example usage
#tokenizer_fr = AutoTokenizer.from_pretrained("tokenizers/fineweb2_fr")
#tokenizer_nl = AutoTokenizer.from_pretrained("tokenizers/fineweb2_nl")
#tokenizer_de = AutoTokenizer.from_pretrained("tokenizers/fineweb2_de")
#tokenizer_en = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
#pick_best_tokenizer = PickBestTokenizer([tokenizer_fr, tokenizer_nl, tokenizer_de, tokenizer_en])

PickBestTokenizer.register_for_auto_class("AutoTokenizer")
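
# A minimal usage sketch (commented out, matching the example above). The
# "bert-base-uncased" and "gpt2" names below are placeholder public tokenizers,
# not the ones used in this project, and "my_pick_best_tokenizer" is just an
# illustrative output directory:
#
# tok = PickBestTokenizer(["bert-base-uncased", "gpt2"])
# tokens = tok.tokenize("Hello world")      # every token carries a "[i]" prefix
# ids = tok.convert_tokens_to_ids(tokens)   # ids live in the combined, offset vocabulary
# assert tok.convert_ids_to_tokens(ids) == tokens
#
# tok.save_pretrained("my_pick_best_tokenizer")
# reloaded = AutoTokenizer.from_pretrained("my_pick_best_tokenizer", trust_remote_code=True)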