# ============================================================
# FILE: tokenizer.py
# ============================================================
import re
import json
from collections import Counter
from typing import List


class MarathiBPETokenizer:
    def __init__(self):
        self.vocab = {}          # Maps token_id to token_str
        self.inverse_vocab = {}  # Maps token_str to token_id
        self.bpe_merges = []     # Ordered list of merge operations
        self.bpe_ranks = {}      # Maps pair to merge rank/priority
        # Regex pattern for splitting text into chunks before encoding
        self.pattern = re.compile(
            r"""[\u0900-\u097F\u1CD0-\u1CFF]+|   # Marathi/Devanagari characters
            [a-zA-Z]+|                           # English words
            [0-9]+|                              # Numbers
            [^\s\w\u0900-\u097F\u1CD0-\u1CFF]+|  # Punctuation/symbols
            \s+                                  # Whitespace
            """,
            re.VERBOSE,
        )
    def train(self, text: str, vocab_size: int):
        """Train the BPE tokenizer from scratch."""
        # Seed the vocabulary with every unique character in the corpus.
        unique_chars = sorted(set(text))
        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}
        print(f"Initial vocab size (unique characters): {len(self.vocab)}")

        initial_token_count = len(text)
        token_ids = [self.inverse_vocab[c] for c in text]
        print(f"Training on {initial_token_count} characters")

        num_merges = vocab_size - len(self.vocab)
        for merge_idx in range(num_merges):
            pair_freqs = self._count_pairs(token_ids)
            if not pair_freqs:
                print(f"No more pairs to merge. Stopping at vocab size {len(self.vocab)}")
                break
            # Greedily merge the most frequent adjacent pair into a new token.
            best_pair = max(pair_freqs.items(), key=lambda x: x[1])[0]
            new_id = len(self.vocab)
            merged_token = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id
            self.bpe_merges.append(best_pair)
            self.bpe_ranks[best_pair] = merge_idx
            token_ids = self._merge_pair(token_ids, best_pair, new_id)
            if (merge_idx + 1) % 1000 == 0:
                print(f"Merged {merge_idx + 1}/{num_merges} pairs, vocab size: {len(self.vocab)}")

        final_token_count = len(token_ids)
        compression_ratio = initial_token_count / final_token_count
        print(f"\n{'='*60}")
        print("Training Complete!")
        print(f"{'='*60}")
        print(f"Final vocab size: {len(self.vocab)}")
        print(f"Original characters: {initial_token_count}")
        print(f"Final BPE tokens: {final_token_count}")
        print(f"Compression ratio: {compression_ratio:.2f}x")
        print(f"{'='*60}\n")
    def _count_pairs(self, token_ids: List[int]) -> Counter:
        """Count frequency of adjacent token pairs."""
        pairs = Counter()
        for i in range(len(token_ids) - 1):
            pairs[(token_ids[i], token_ids[i + 1])] += 1
        return pairs

    def _merge_pair(self, token_ids: List[int], pair: tuple, new_id: int) -> List[int]:
        """Replace all occurrences of pair with new_id."""
        result = []
        i = 0
        while i < len(token_ids):
            if i < len(token_ids) - 1 and (token_ids[i], token_ids[i + 1]) == pair:
                result.append(new_id)
                i += 2
            else:
                result.append(token_ids[i])
                i += 1
        return result
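    # Example of the merge step: _merge_pair([5, 3, 7, 5, 3], (5, 3), 9)
    # scans left to right and returns [9, 7, 9]; matched pairs are consumed
    # greedily and never re-examined within the same pass.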
    def _apply_bpe(self, token_str: str) -> List[int]:
        """Apply BPE merges to a string token."""
        # Map characters to ids; characters unseen during training are
        # silently dropped (this vocabulary has no <unk> token).
        token_ids = [self.inverse_vocab[c] for c in token_str if c in self.inverse_vocab]
        if len(token_ids) <= 1:
            return token_ids
        # Repeatedly apply the lowest-rank (earliest-learned) merge available,
        # reproducing the same segmentation order as training.
        while len(token_ids) > 1:
            min_rank = float('inf')
            min_pos = -1
            for i in range(len(token_ids) - 1):
                pair = (token_ids[i], token_ids[i + 1])
                if pair in self.bpe_ranks:
                    rank = self.bpe_ranks[pair]
                    if rank < min_rank:
                        min_rank = rank
                        min_pos = i
            if min_pos == -1:
                break  # no applicable merges remain
            pair = (token_ids[min_pos], token_ids[min_pos + 1])
            merged_token_str = self.vocab[pair[0]] + self.vocab[pair[1]]
            new_id = self.inverse_vocab[merged_token_str]
            token_ids = token_ids[:min_pos] + [new_id] + token_ids[min_pos + 2:]
        return token_ids
    def encode(self, text: str) -> List[int]:
        """Encode text into token IDs."""
        chunks = self.pattern.findall(text)
        token_ids = []
        for chunk in chunks:
            token_ids.extend(self._apply_bpe(chunk))
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Convert token IDs back to text; unknown IDs are skipped."""
        result = []
        for token_id in token_ids:
            if token_id in self.vocab:
                result.append(self.vocab[token_id])
        return "".join(result)
    def save_vocab(self, filepath: str):
        """Save vocabulary and merge rules to a JSON file."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump({
                'vocab': {str(k): v for k, v in self.vocab.items()},
                'bpe_merges': [[p[0], p[1]] for p in self.bpe_merges]
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved vocabulary to {filepath}")

    def load_vocab(self, filepath: str):
        """Load vocabulary and merge rules from a JSON file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.vocab = {int(k): v for k, v in data['vocab'].items()}
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.bpe_merges = [tuple(pair) for pair in data['bpe_merges']]
        self.bpe_ranks = {tuple(pair): idx for idx, pair in enumerate(self.bpe_merges)}
        print(f"Loaded vocabulary from {filepath} (size: {len(self.vocab)})")