# ============================================================
# FILE: tokenizer.py
# ============================================================
import json
import re
from collections import Counter
from typing import List


class MarathiBPETokenizer:
    """A byte-pair-encoding (BPE) tokenizer for Marathi (Devanagari) text."""

    def __init__(self):
        self.vocab = {}          # Maps token_id to token_str
        self.inverse_vocab = {}  # Maps token_str to token_id
        self.bpe_merges = []     # Ordered list of merge operations
        self.bpe_ranks = {}      # Maps pair to merge rank/priority
        # Regex pattern for pre-tokenizing text into chunks before BPE
        self.pattern = re.compile(
            r"""[\u0900-\u097F\u1CD0-\u1CFF]+|   # Marathi/Devanagari characters
            [a-zA-Z]+|                           # English words
            [0-9]+|                              # Numbers
            [^\s\w\u0900-\u097F\u1CD0-\u1CFF]+|  # Punctuation/symbols
            \s+                                  # Whitespace
            """,
            re.VERBOSE
        )
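
        # For example, the pattern splits mixed-script text into
        # script-homogeneous chunks before any merges are applied:
        #   re.findall(self.pattern, "नमस्कार, world 123!")
        #   -> ['नमस्कार', ',', ' ', 'world', ' ', '123', '!']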

    def train(self, text: str, vocab_size: int):
        """Train the BPE tokenizer from scratch."""
        # Base vocabulary: every unique character in the training text.
        unique_chars = sorted(set(text))
        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}
        print(f"Initial vocab size (unique characters): {len(self.vocab)}")

        initial_token_count = len(text)
        token_ids = [self.inverse_vocab[c] for c in text]
        print(f"Training on {initial_token_count} characters")

        num_merges = vocab_size - len(self.vocab)
        for merge_idx in range(num_merges):
            pair_freqs = self._count_pairs(token_ids)
            if not pair_freqs:
                print(f"No more pairs to merge. Stopping at vocab size {len(self.vocab)}")
                break
            # Greedily merge the most frequent adjacent pair.
            best_pair = max(pair_freqs.items(), key=lambda x: x[1])[0]
            new_id = len(self.vocab)
            merged_token = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id
            self.bpe_merges.append(best_pair)
            self.bpe_ranks[best_pair] = merge_idx
            token_ids = self._merge_pair(token_ids, best_pair, new_id)
            if (merge_idx + 1) % 1000 == 0:
                print(f"Merged {merge_idx + 1}/{num_merges} pairs, vocab size: {len(self.vocab)}")

        final_token_count = len(token_ids)
        compression_ratio = initial_token_count / final_token_count
        print(f"\n{'='*60}")
        print("Training Complete!")
        print(f"{'='*60}")
        print(f"Final vocab size: {len(self.vocab)}")
        print(f"Original characters: {initial_token_count}")
        print(f"Final BPE tokens: {final_token_count}")
        print(f"Compression ratio: {compression_ratio:.2f}x")
        print(f"{'='*60}\n")

    def _count_pairs(self, token_ids: List[int]) -> Counter:
        """Count frequency of adjacent token pairs."""
        pairs = Counter()
        for i in range(len(token_ids) - 1):
            pairs[(token_ids[i], token_ids[i + 1])] += 1
        return pairs
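
    # Example: _count_pairs([5, 6, 5, 6, 7]) yields
    # Counter({(5, 6): 2, (6, 5): 1, (6, 7): 1}).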

    def _merge_pair(self, token_ids: List[int], pair: tuple, new_id: int) -> List[int]:
        """Replace all occurrences of pair with new_id."""
        result = []
        i = 0
        while i < len(token_ids):
            if i < len(token_ids) - 1 and (token_ids[i], token_ids[i + 1]) == pair:
                result.append(new_id)
                i += 2
            else:
                result.append(token_ids[i])
                i += 1
        return result
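
    # Example: _merge_pair([5, 6, 5, 6, 7], (5, 6), 8) -> [8, 8, 7];
    # occurrences are scanned left to right and consumed non-overlappingly.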

    def _apply_bpe(self, token_str: str) -> List[int]:
        """Apply BPE merges to a single pre-tokenized chunk."""
        # Map characters to base IDs; characters never seen during training
        # have no ID and are silently dropped.
        token_ids = [self.inverse_vocab[char] for char in token_str
                     if char in self.inverse_vocab]
        if len(token_ids) <= 1:
            return token_ids
        while len(token_ids) > 1:
            # Find the adjacent pair with the lowest (earliest-learned) rank.
            min_rank = float('inf')
            min_pos = -1
            for i in range(len(token_ids) - 1):
                pair = (token_ids[i], token_ids[i + 1])
                if pair in self.bpe_ranks:
                    rank = self.bpe_ranks[pair]
                    if rank < min_rank:
                        min_rank = rank
                        min_pos = i
            if min_pos == -1:
                break  # no learned merge applies anywhere in the chunk
            pair = (token_ids[min_pos], token_ids[min_pos + 1])
            merged_token_str = self.vocab[pair[0]] + self.vocab[pair[1]]
            new_id = self.inverse_vocab[merged_token_str]
            token_ids = token_ids[:min_pos] + [new_id] + token_ids[min_pos + 2:]
        return token_ids
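
    # Merges replay in training order: at every step the lowest-ranked
    # (earliest-learned) applicable pair is merged first, which is the
    # standard BPE encoding rule.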

    def encode(self, text: str) -> List[int]:
        """Encode text into token IDs."""
        chunks = re.findall(self.pattern, text)
        token_ids = []
        for chunk in chunks:
            token_ids.extend(self._apply_bpe(chunk))
        return token_ids
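
    # Note: because text is pre-chunked by self.pattern, merges never cross
    # chunk boundaries at encode time (e.g. a Devanagari run is never fused
    # with adjacent punctuation), even if such a pair was learned in training.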

    def decode(self, token_ids: List[int]) -> str:
        """Convert token IDs back to text."""
        result = []
        for token_id in token_ids:
            if token_id in self.vocab:  # unknown IDs are skipped silently
                result.append(self.vocab[token_id])
        return "".join(result)

    def save_vocab(self, filepath: str):
        """Save vocabulary and merge rules to a JSON file."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump({
                'vocab': {str(k): v for k, v in self.vocab.items()},
                'bpe_merges': [[p[0], p[1]] for p in self.bpe_merges]
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved vocabulary to {filepath}")

    def load_vocab(self, filepath: str):
        """Load vocabulary and merge rules from a JSON file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.vocab = {int(k): v for k, v in data['vocab'].items()}
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.bpe_merges = [tuple(pair) for pair in data['bpe_merges']]
        self.bpe_ranks = {tuple(pair): idx for idx, pair in enumerate(self.bpe_merges)}
        print(f"Loaded vocabulary from {filepath} (size: {len(self.vocab)})")