nikhiljais committed
Commit 2a48d90 · verified · 1 Parent(s): a9a3cd4

Create tokenizer.py

Files changed (1)
  1. tokenizer.py +199 -0
tokenizer.py ADDED
import json
import logging
from typing import List, Dict, Tuple
from tqdm import tqdm

class SimpleBPETokenizer:
    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.merges = {}  # (int, int) -> int
        self.vocab = set(range(256))  # Initial vocab is byte values 0-255

    def _text_to_bytes(self, text: str) -> List[int]:
        """Convert text to a list of byte values"""
        return list(text.encode('utf-8'))

    def _get_stats(self, ids: List[int]) -> Dict[Tuple[int, int], int]:
        """Count frequency of adjacent pairs"""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

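    # Illustrative example (not in the original file):
    #   _get_stats([1, 2, 3, 1, 2]) -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
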
    def _merge(self, ids: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
        """Merge all occurrences of pair into new token idx"""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids

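    # Illustrative example (not in the original file): with pair=(1, 2) and
    # idx=256, _merge([1, 2, 3, 1, 2], (1, 2), 256) returns [256, 3, 256].
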
    def fit(self, texts: List[str]):
        """Train the tokenizer using byte-level BPE"""
        # Convert all texts to one long byte sequence
        logging.info("Converting texts to bytes...")
        all_ids = []
        for text in tqdm(texts, desc="Processing texts"):
            all_ids.extend(self._text_to_bytes(text))

        # Number of merges needed on top of the 256 initial byte tokens
        num_merges = self.vocab_size - 256
        original_len = len(all_ids)  # Precomputed once; reused for ratio logging

        # Perform merges
        next_id = 256
        with tqdm(total=num_merges, desc="BPE merges") as pbar:
            for i in range(num_merges):
                # Get pair frequencies
                stats = self._get_stats(all_ids)
                if not stats:
                    break

                # Find the most frequent pair
                pair = max(stats.items(), key=lambda x: x[1])[0]

                # Perform the merge
                all_ids = self._merge(all_ids, pair, next_id)
                self.merges[pair] = next_id
                self.vocab.add(next_id)

                # Log progress
                if i % 100 == 0:
                    logging.info(f"merging {pair} into new token {next_id}")
                    compression = original_len / len(all_ids)
                    logging.info(f"Current compression ratio: {compression:.2f}X")

                next_id += 1
                pbar.update(1)

        # Calculate the final ratio
        compression = original_len / len(all_ids)
        logging.info(f"Final compression ratio: {compression:.2f}X")

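    # Illustrative trace (not in the original file): for the text "aaab" the
    # byte ids are [97, 97, 97, 98]; the most frequent pair (97, 97) is merged
    # into 256 first, giving [256, 97, 98], and later merges chain new ids into
    # progressively longer byte sequences.
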
    def encode(self, text: str) -> List[int]:
        """Encode text using the learned merges"""
        ids = self._text_to_bytes(text)

        # Apply merges in training order (dicts preserve insertion order)
        for pair, new_id in self.merges.items():
            ids = self._merge(ids, pair, new_id)

        return ids

    def _expand(self, token_id: int) -> List[int]:
        """Recursively expand a token id into its constituent byte values"""
        if token_id < 256:
            return [token_id]
        # Reverse-lookup the pair that was merged into this id
        for pair, merge_id in self.merges.items():
            if merge_id == token_id:
                return self._expand(pair[0]) + self._expand(pair[1])
        raise ValueError(f"Unknown token id: {token_id}")

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text"""
        bytes_list = []
        for token_id in ids:
            bytes_list.extend(self._expand(token_id))
        return bytes(bytes_list).decode('utf-8')

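    # Note (added): the linear scan in _expand makes decoding O(len(merges))
    # per merged token. A hypothetical speed-up, not in the original, is to
    # build the inverse map once, e.g.
    #   inverse = {v: k for k, v in self.merges.items()}
    # and look pairs up in O(1).
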
    def calculate_compression_ratio(self, texts: List[str]) -> float:
        """Calculate the compression ratio over multiple texts"""
        total_original = 0
        total_merged = 0

        for text in texts:
            original_tokens = self._text_to_bytes(text)
            merged_tokens = self.encode(text)
            total_original += len(original_tokens)
            total_merged += len(merged_tokens)

        return total_original / total_merged if total_merged > 0 else 0.0

    def save(self, path: str):
        """Save the tokenizer to a JSON file"""
        data = {
            'vocab_size': self.vocab_size,
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()},  # Tuple keys -> strings
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logging.info(f"Tokenizer saved to {path}")

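    # The saved JSON has this shape (illustrative values):
    #   {"vocab_size": 5000, "merges": {"97,97": 256, ...}, "vocab": [0, 1, ...]}
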
    @classmethod
    def load(cls, path: str) -> 'SimpleBPETokenizer':
        """Load a tokenizer from a JSON file"""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        tokenizer = cls(vocab_size=data['vocab_size'])
        tokenizer.vocab = set(data['vocab'])
        # Convert string keys back to tuples
        tokenizer.merges = {tuple(map(int, k.split(','))): v
                            for k, v in data['merges'].items()}
        return tokenizer

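# Illustrative round-trip (a sketch, not part of the original script; the toy
# corpus and vocab size are arbitrary choices):
#   tok = SimpleBPETokenizer(vocab_size=300)
#   tok.fit(["hello hello world"])
#   ids = tok.encode("hello world")
#   assert tok.decode(ids) == "hello world"  # byte-level BPE round-trips losslessly
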
import gzip
import random

# Path to the .gz corpus file
file_path = '/home/nikhil/m2m_train/NMT_DETAILS_AUG_24/ass10/data/hi.txt.gz'

# Read the file line by line into a list of stripped strings
with gzip.open(file_path, 'rt', encoding='utf-8') as f:
    text = [l.strip() for l in f.readlines()]

print(len(text))

# Sample 1000 lines for training
texts = random.sample(text, 1000)
print(len(texts))

import time

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('tokenizer_training.log'),
        logging.StreamHandler()
    ]
)

start_time = time.time()
sam = texts

# Initialize and train the tokenizer
tokenizer = SimpleBPETokenizer(vocab_size=5000)
tokenizer.fit(sam)
logging.info(f"Total training time: {time.time() - start_time:.2f} seconds")

start_time = time.time()

# Calculate the compression ratio and save the tokenizer
final_ratio = tokenizer.calculate_compression_ratio(sam)
print(f"Compression ratio: {final_ratio:.2f}X")
tokenizer.save('hindi_tokenizer.json')

logging.info(f"Total ratio calculation time: {time.time() - start_time:.2f} seconds")