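"""Simple character-level BPE tokenizer (tokenizer.py from the HindiBPETokenizer repo).

Trains a byte-pair-encoding vocabulary over the characters of the input texts,
then encodes and decodes text by greedy longest-token matching.
"""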
import numpy as np
from collections import Counter
from typing import List, Dict, Tuple
import json
from tqdm import tqdm
import logging


class SimpleBPETokenizer:
    """Character-level BPE tokenizer that greedily merges the most frequent adjacent pairs."""

    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.stoi = {}    # String to index mapping
        self.itos = {}    # Index to string mapping
        self.merges = {}  # Store merge rules

    def _get_stats(self, data: np.ndarray) -> Dict[Tuple[int, int], int]:
        """Get frequencies of adjacent pairs"""
        pairs = np.vstack((data[:-1], data[1:])).T
        return Counter(map(tuple, pairs))

    def _merge_pair(self, data: np.ndarray, pair: Tuple[int, int], new_id: int) -> np.ndarray:
        """Merge all occurrences of pair into a single new token.

        E.g. data [1, 2, 3, 1, 2] with pair (1, 2) -> new_id 7 becomes [7, 3, 7].
        """
        i = 0
        result = []
        while i < len(data):
            if i < len(data) - 1 and data[i] == pair[0] and data[i + 1] == pair[1]:
                result.append(new_id)
                i += 2
            else:
                result.append(data[i])
                i += 1
        return np.array(result, dtype=np.int32)

    def fit(self, texts: List[str]):
        """Train tokenizer using BPE"""
        logging.info("Starting tokenizer training...")

        # Get unique characters from all texts
        all_chars = sorted(set(''.join(texts)))
        logging.info(f"Found {len(all_chars)} unique characters")

        # Create initial vocabulary
        self.stoi = {ch: i for i, ch in enumerate(all_chars)}
        self.itos = {i: ch for i, ch in enumerate(all_chars)}
        next_id = len(all_chars)

        # Initial encoding of all texts
        logging.info("Performing initial encoding...")
        data = []
        for text in tqdm(texts, desc="Initial encoding"):
            data.extend([self.stoi[c] for c in text])
        data = np.array(data, dtype=np.int32)
        logging.info(f"Initial vocabulary size: {len(self.stoi)}")
        logging.info(f"Target vocabulary size: {self.vocab_size}")

        # Main BPE loop: repeatedly merge the most frequent adjacent pair
        pbar = tqdm(total=self.vocab_size - len(self.stoi), desc="BPE merges")
        while len(self.stoi) < self.vocab_size:
            # Get pair frequencies
            stats = self._get_stats(data)
            if not stats:
                logging.info("No more pairs to merge")
                break

            # Find most frequent pair
            pair = max(stats.items(), key=lambda x: x[1])[0]
            freq = stats[pair]
            if freq < 2:  # Stop once no pair occurs more than once
                logging.info("No more frequent pairs to merge")
                break

            # Create new token from pair
            new_token = self.itos[pair[0]] + self.itos[pair[1]]
            self.stoi[new_token] = next_id
            self.itos[next_id] = new_token
            self.merges[pair] = next_id

            # Merge pair in data
            data = self._merge_pair(data, pair, next_id)

            if len(self.stoi) % 100 == 0:
                logging.info(f"Merged pair {pair} (freq: {freq}) into token '{new_token}' (id: {next_id})")
                logging.info(f"Current vocabulary size: {len(self.stoi)}")

            next_id += 1
            pbar.update(1)
        pbar.close()

        # Log final statistics
        logging.info("Training completed:")
        logging.info(f"Final vocabulary size: {len(self.stoi)}")
        logging.info(f"Total merges performed: {len(self.merges)}")

        # Compression ratio: character-level token count vs. token count after BPE
        total_initial_tokens = sum(self.get_initial_tokens_length(text) for text in texts)
        total_encoded_tokens = len(data)
        compression = total_initial_tokens / total_encoded_tokens
        logging.info(f"Initial tokens (character-level): {total_initial_tokens:,}")
        logging.info(f"Final tokens (after BPE): {total_encoded_tokens:,}")
        logging.info(f"Compression ratio: {compression:.2f}X")

    def get_initial_tokens_length(self, text: str) -> int:
        """Get number of tokens using initial vocabulary (character-level)"""
        return sum(1 for c in text if c in self.stoi)

    def encode(self, text: str) -> List[int]:
        """Encode text by greedily matching the longest known token at each position"""
        token_ids = []
        i = 0
        while i < len(text):
            # Check vocabulary tokens longest-first, so the first match is the longest one
            longest_match = None
            for token, idx in sorted(self.stoi.items(), key=lambda x: len(x[0]), reverse=True):
                if text[i:].startswith(token):
                    longest_match = (token, idx)
                    break
            if longest_match:
                token, idx = longest_match
                token_ids.append(idx)
                i += len(token)
            elif text[i] in self.stoi:
                # Fall back to the single character if no longer token matched
                token_ids.append(self.stoi[text[i]])
                i += 1
            else:
                logging.warning(f"Unknown character '{text[i]}' at position {i}")
                i += 1
        return token_ids

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text"""
        return ''.join(self.itos[token_id] for token_id in ids)

    def save(self, path: str):
        """Save tokenizer to JSON"""
        data = {
            'vocab_size': self.vocab_size,
            'stoi': self.stoi,
            'itos': {str(k): v for k, v in self.itos.items()},  # Convert int keys to str
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logging.info(f"Tokenizer saved to {path}")

    @classmethod
    def load(cls, path: str) -> 'SimpleBPETokenizer':
        """Load tokenizer from JSON"""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        tokenizer = cls(vocab_size=data['vocab_size'])
        tokenizer.stoi = data['stoi']
        tokenizer.itos = {int(k): v for k, v in data['itos'].items()}  # Convert str keys back to int
        tokenizer.merges = {tuple(map(int, k.split(','))): v for k, v in data['merges'].items()}
        return tokenizer
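

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how the class above might be exercised end to end;
# the sample strings, the small vocab_size, and the "tokenizer.json" path
# are placeholders, not values taken from the original repo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample_texts = ["नमस्ते दुनिया", "नमस्ते भारत"]  # hypothetical training texts
    tokenizer = SimpleBPETokenizer(vocab_size=50)
    tokenizer.fit(sample_texts)

    ids = tokenizer.encode("नमस्ते")
    print(ids)
    print(tokenizer.decode(ids))  # should round-trip to the input text

    tokenizer.save("tokenizer.json")
    reloaded = SimpleBPETokenizer.load("tokenizer.json")
    assert reloaded.decode(reloaded.encode("नमस्ते")) == "नमस्ते"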