nikhiljais committed on
Create tokenizer.py

tokenizer.py ADDED (+199 -0)
@@ -0,0 +1,199 @@
import regex as re
from collections import Counter
from typing import List, Dict, Tuple, Set
import json
from tqdm import tqdm
import logging
import numpy as np


class SimpleBPETokenizer:
    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.merges = {}  # (int, int) -> int
        self.vocab = set(range(256))  # Initial vocab is byte values 0-255

    def _text_to_bytes(self, text: str) -> List[int]:
        """Convert text to list of byte values"""
        return list(text.encode('utf-8'))

    def _get_stats(self, ids: List[int]) -> Dict[Tuple[int, int], int]:
        """Count frequency of adjacent pairs"""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, ids: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
        """Merge all occurrences of pair into new token idx"""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids
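
    # Worked example (illustrative, not part of the training flow): for ids = [1, 2, 3, 1, 2],
    # _get_stats returns {(1, 2): 2, (2, 3): 1, (3, 1): 1}, and
    # _merge(ids, (1, 2), 256) rewrites the sequence to [256, 3, 256].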

    def fit(self, texts: List[str]):
        """Train tokenizer using byte-level BPE"""
        # Convert all texts to byte sequences
        logging.info("Converting texts to bytes...")
        all_ids = []
        for text in tqdm(texts, desc="Processing texts"):
            all_ids.extend(self._text_to_bytes(text))

        # Calculate number of merges needed
        num_merges = self.vocab_size - 256  # 256 initial byte tokens

        # Perform merges
        next_id = 256
        with tqdm(total=num_merges, desc="BPE merges") as pbar:
            for i in range(num_merges):
                # Get pair frequencies
                stats = self._get_stats(all_ids)
                if not stats:
                    break

                # Find most frequent pair
                pair = max(stats.items(), key=lambda x: x[1])[0]

                # Perform merge
                all_ids = self._merge(all_ids, pair, next_id)
                self.merges[pair] = next_id
                self.vocab.add(next_id)

                # Log progress
                if i % 100 == 0:
                    logging.info(f"merging {pair} into new token {next_id}")
                    compression = len(self._text_to_bytes(''.join(texts))) / len(all_ids)
                    logging.info(f"Current compression ratio: {compression:.2f}X")

                next_id += 1
                pbar.update(1)

        # Calculate final ratio
        original_len = sum(len(text.encode('utf-8')) for text in texts)
        compression = original_len / len(all_ids)
        logging.info(f"Final compression ratio: {compression:.2f}X")

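    # For the default vocab_size=5000, fit() performs 5000 - 256 = 4744 merges;
    # each new id (256, 257, ...) stands for the currently most frequent adjacent pair.
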
    def encode(self, text: str) -> List[int]:
        """Encode text using learned merges"""
        ids = self._text_to_bytes(text)

        # Apply merges in order
        for pair, new_id in self.merges.items():
            ids = self._merge(ids, pair, new_id)

        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text"""
        bytes_list = []
        # Expand merged tokens back into raw byte values before UTF-8 decoding
        stack = list(reversed(ids))
        while stack:
            id = stack.pop()
            if id < 256:
                bytes_list.append(id)
            else:
                for pair, merge_id in self.merges.items():
                    if merge_id == id:
                        # Re-push the pair so it is expanded in the original order
                        stack.append(pair[1])
                        stack.append(pair[0])
                        break

        return bytes(bytes_list).decode('utf-8')

    def calculate_compression_ratio(self, texts: List[str]) -> float:
        """Calculate compression ratio for multiple texts"""
        total_original = 0
        total_merged = 0

        for text in texts:
            original_tokens = self._text_to_bytes(text)
            merged_tokens = self.encode(text)
            total_original += len(original_tokens)
            total_merged += len(merged_tokens)

        return total_original / total_merged if total_merged > 0 else 0.0

    def save(self, path: str):
        """Save tokenizer to a JSON file"""
        data = {
            'vocab_size': self.vocab_size,
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()},  # Convert tuples to string
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logging.info(f"Tokenizer saved to {path}")

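    # Resulting file layout (illustrative sketch; the merge values shown are made up):
    # {
    #   "vocab_size": 5000,
    #   "merges": {"224,164": 256, "256,164": 257, ...},
    #   "vocab": [0, 1, 2, ..., 256, 257, ...]
    # }
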
    @classmethod
    def load(cls, path: str) -> 'SimpleBPETokenizer':
        """Load tokenizer from a JSON file"""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        tokenizer = cls(vocab_size=data['vocab_size'])
        tokenizer.vocab = set(data['vocab'])
        # Convert string keys back to tuples
        tokenizer.merges = {tuple(map(int, k.split(','))): v
                            for k, v in data['merges'].items()}
        return tokenizer

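# Quick usage sketch (illustrative; kept as comments so the script below is unchanged,
# and the corpus and variable names here are made up):
#
#     demo_corpus = ["नमस्ते दुनिया", "यह एक परीक्षण है"]
#     demo_tok = SimpleBPETokenizer(vocab_size=300)
#     demo_tok.fit(demo_corpus)
#     ids = demo_tok.encode("नमस्ते")
#     assert demo_tok.decode(ids) == "नमस्ते"  # byte-level BPE round-trips losslessly
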
import gzip
import io
import re

# Path to your .gz file
file_path = '/home/nikhil/m2m_train/NMT_DETAILS_AUG_24/ass10/data/hi.txt.gz'

# Read the corpus line by line and strip trailing whitespace
with gzip.open(file_path, 'rt', encoding='utf-8') as f:
    text = f.readlines()
text = [l.strip() for l in text]

print(len(text))


import random
texts = random.sample(text, 1000)
print(len(texts))

import re, time
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set
import json
import regex as re
import logging
from tqdm import tqdm
import unicodedata

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('tokenizer_training.log'),
        logging.StreamHandler()
    ]
)


start_time = time.time()
sam = texts

# Initialize and train tokenizer
tokenizer = SimpleBPETokenizer(vocab_size=5000)


tokenizer.fit(sam)
logging.info(f"Total Training time: {time.time() - start_time:.2f} seconds")


start_time = time.time()

# Calculate compression ratio
final_ratio = tokenizer.calculate_compression_ratio(sam)
print(final_ratio)
tokenizer.save('hindi_tokenizer.json')

logging.info(f"Total Ratio Calculation time: {time.time() - start_time:.2f} seconds")
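
# To reuse the trained tokenizer later (illustrative sketch; 'hindi_tokenizer.json'
# is the file written above, the sample sentence is made up):
#
#     loaded = SimpleBPETokenizer.load('hindi_tokenizer.json')
#     sample = "मुझे हिंदी पसंद है"
#     ids = loaded.encode(sample)
#     print(len(sample.encode('utf-8')), len(ids))  # raw bytes vs. BPE tokens
#     print(loaded.decode(ids) == sample)           # True: round-trip is lossless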