import numpy as np
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
import json
from tqdm import tqdm
import logging
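
# A minimal character-level byte-pair-encoding (BPE) tokenizer:
#   1. Build the initial vocabulary from every unique character in the training texts.
#   2. Repeatedly merge the most frequent adjacent token pair into a new token
#      until the target vocabulary size is reached (or no pair occurs at least twice).
#   3. encode() greedily matches the longest known token at each position;
#      decode() simply concatenates the token strings.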

class SimpleBPETokenizer:
    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.stoi = {}    # String to index mapping
        self.itos = {}    # Index to string mapping
        self.merges = {}  # Learned merge rules: (id, id) -> new id

    def _get_stats(self, data: np.ndarray) -> Dict[Tuple[int, int], int]:
        """Get frequencies of adjacent token pairs"""
        pairs = np.vstack((data[:-1], data[1:])).T
        return Counter(map(tuple, pairs))

    def _merge_pair(self, data: np.ndarray, pair: Tuple[int, int], new_id: int) -> np.ndarray:
        """Merge all occurrences of pair into a single new token"""
        i = 0
        result = []
        while i < len(data):
            if i < len(data) - 1 and data[i] == pair[0] and data[i + 1] == pair[1]:
                result.append(new_id)
                i += 2
            else:
                result.append(data[i])
                i += 1
        return np.array(result, dtype=np.int32)

    def fit(self, texts: List[str]):
        """Train the tokenizer using BPE"""
        logging.info("Starting tokenizer training...")

        # Get unique characters from all texts
        all_chars = sorted(set(''.join(texts)))
        logging.info(f"Found {len(all_chars)} unique characters")

        # Create initial character-level vocabulary
        self.stoi = {ch: i for i, ch in enumerate(all_chars)}
        self.itos = {i: ch for i, ch in enumerate(all_chars)}
        next_id = len(all_chars)

        # Initial encoding of all texts
        logging.info("Performing initial encoding...")
        data = []
        for text in tqdm(texts, desc="Initial encoding"):
            data.extend([self.stoi[c] for c in text])
        data = np.array(data, dtype=np.int32)

        logging.info(f"Initial vocabulary size: {len(self.stoi)}")
        logging.info(f"Target vocabulary size: {self.vocab_size}")

        # Main BPE loop
        pbar = tqdm(total=self.vocab_size - len(self.stoi), desc="BPE merges")
        while len(self.stoi) < self.vocab_size:
            # Get pair frequencies
            stats = self._get_stats(data)
            if not stats:
                logging.info("No more pairs to merge")
                break

            # Find most frequent pair
            pair = max(stats.items(), key=lambda x: x[1])[0]
            freq = stats[pair]
            if freq < 2:  # Stop once the best remaining pair occurs only once
                logging.info("No more frequent pairs to merge")
                break

            # Create new token from pair
            new_token = self.itos[pair[0]] + self.itos[pair[1]]
            self.stoi[new_token] = next_id
            self.itos[next_id] = new_token
            self.merges[pair] = next_id

            # Merge pair in data
            data = self._merge_pair(data, pair, next_id)

            if len(self.stoi) % 100 == 0:
                logging.info(f"Merged pair {pair} (freq: {freq}) into token '{new_token}' (id: {next_id})")
                logging.info(f"Current vocabulary size: {len(self.stoi)}")

            next_id += 1
            pbar.update(1)
        pbar.close()

        # Log final statistics
        logging.info("Training completed:")
        logging.info(f"Final vocabulary size: {len(self.stoi)}")
        logging.info(f"Total merges performed: {len(self.merges)}")

        # Compression ratio: character-level token count vs. token count after BPE merges
        total_initial_tokens = sum(self.get_initial_tokens_length(text) for text in texts)
        total_encoded_tokens = len(data)
        compression = total_initial_tokens / total_encoded_tokens
        logging.info(f"Initial tokens (character-level): {total_initial_tokens:,}")
        logging.info(f"Final tokens (after BPE): {total_encoded_tokens:,}")
        logging.info(f"Compression ratio: {compression:.2f}X")

    def get_initial_tokens_length(self, text: str) -> int:
        """Get number of tokens using the initial character-level vocabulary"""
        return sum(1 for c in text if c in self.stoi)

    def encode(self, text: str) -> List[int]:
        """Encode text by greedily matching the longest known token at each position"""
        # Sort tokens once, longest first, so the first match found is the longest match
        tokens_by_length = sorted(self.stoi.items(), key=lambda x: len(x[0]), reverse=True)
        token_ids = []
        i = 0
        while i < len(text):
            match = None
            for token, idx in tokens_by_length:
                if text.startswith(token, i):
                    match = (token, idx)
                    break
            if match:
                token, idx = match
                token_ids.append(idx)
                i += len(token)
            else:
                # Character not seen during training; skip it with a warning
                logging.warning(f"Unknown character '{text[i]}' at position {i}")
                i += 1
        return token_ids

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text"""
        return ''.join(self.itos[token_id] for token_id in ids)

    def save(self, path: str):
        """Save tokenizer to JSON"""
        data = {
            'vocab_size': self.vocab_size,
            'stoi': self.stoi,
            'itos': {str(k): v for k, v in self.itos.items()},  # JSON object keys must be strings
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logging.info(f"Tokenizer saved to {path}")

    @classmethod
    def load(cls, path: str) -> 'SimpleBPETokenizer':
        """Load tokenizer from JSON"""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        tokenizer = cls(vocab_size=data['vocab_size'])
        tokenizer.stoi = data['stoi']
        tokenizer.itos = {int(k): v for k, v in data['itos'].items()}  # Convert str keys back to int
        tokenizer.merges = {tuple(map(int, k.split(','))): v for k, v in data['merges'].items()}
        return tokenizer
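

# Usage sketch (illustrative only): the corpus, vocabulary size, and file path
# below are placeholders, not part of the tokenizer itself.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Tiny toy corpus; in practice this would be a real text dataset
    corpus = [
        "the quick brown fox jumps over the lazy dog",
        "the lazy dog sleeps while the quick fox runs",
    ]

    tokenizer = SimpleBPETokenizer(vocab_size=80)
    tokenizer.fit(corpus)

    ids = tokenizer.encode("the quick dog")
    print(ids)
    print(tokenizer.decode(ids))  # Round-trips back to "the quick dog"

    # Persist and reload; the restored tokenizer should encode/decode identically
    tokenizer.save("tokenizer.json")
    restored = SimpleBPETokenizer.load("tokenizer.json")
    assert restored.decode(restored.encode("the quick dog")) == "the quick dog"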