# ============================================================
# FILE: tokenizer.py
# ============================================================
import re
import json
from collections import Counter
from typing import List


class MarathiBPETokenizer:
    def __init__(self):
        self.vocab = {}          # Maps token_id to token_str
        self.inverse_vocab = {}  # Maps token_str to token_id
        self.bpe_merges = []     # Ordered list of merge operations
        self.bpe_ranks = {}      # Maps pair to merge rank/priority
        # Regex pattern for splitting text into chunks before BPE
        self.pattern = re.compile(
            r"""[\u0900-\u097F\u1CD0-\u1CFF]+| # Marathi/Devanagari characters
            [a-zA-Z]+|                         # English words
            [0-9]+|                            # Numbers
            [^\s\w\u0900-\u097F\u1CD0-\u1CFF]+| # Punctuation/symbols
            \s+                                # Whitespace
            """,
            re.VERBOSE,
        )
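
    # Illustrative example (assumed input, not from the project): on the
    # mixed string "नमस्कार, hello 123" this pattern yields the chunks
    # ['नमस्कार', ',', ' ', 'hello', ' ', '123'], keeping Devanagari runs,
    # Latin words, digits, punctuation, and whitespace separate.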
    def train(self, text: str, vocab_size: int):
        """Train the BPE tokenizer from scratch."""
        # Seed the vocabulary with every unique character in the corpus.
        unique_chars = sorted(set(text))
        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}
        print(f"Initial vocab size (unique characters): {len(self.vocab)}")

        initial_token_count = len(text)
        token_ids = [self.inverse_vocab[c] for c in text]
        print(f"Training on {initial_token_count} characters")

        num_merges = vocab_size - len(self.vocab)
        for merge_idx in range(num_merges):
            pair_freqs = self._count_pairs(token_ids)
            if not pair_freqs:
                print(f"No more pairs to merge. Stopping at vocab size {len(self.vocab)}")
                break
            # Greedily merge the most frequent adjacent pair.
            best_pair = max(pair_freqs.items(), key=lambda x: x[1])[0]
            new_id = len(self.vocab)
            merged_token = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id
            self.bpe_merges.append(best_pair)
            self.bpe_ranks[best_pair] = merge_idx
            token_ids = self._merge_pair(token_ids, best_pair, new_id)
            if (merge_idx + 1) % 1000 == 0:
                print(f"Merged {merge_idx + 1}/{num_merges} pairs, vocab size: {len(self.vocab)}")

        final_token_count = len(token_ids)
        # Guard against empty input so the ratio never divides by zero.
        compression_ratio = initial_token_count / final_token_count if final_token_count else 0.0
        print(f"\n{'='*60}")
        print("Training Complete!")
        print(f"{'='*60}")
        print(f"Final vocab size: {len(self.vocab)}")
        print(f"Original characters: {initial_token_count}")
        print(f"Final BPE tokens: {final_token_count}")
        print(f"Compression ratio: {compression_ratio:.2f}x")
        print(f"{'='*60}\n")
    def _count_pairs(self, token_ids: List[int]) -> Counter:
        """Count frequency of adjacent token pairs."""
        pairs = Counter()
        for left, right in zip(token_ids, token_ids[1:]):
            pairs[(left, right)] += 1
        return pairs
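
    # Illustrative example (assumed IDs): _count_pairs([1, 2, 2, 3]) returns
    # Counter({(1, 2): 1, (2, 2): 1, (2, 3): 1}).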
    def _merge_pair(self, token_ids: List[int], pair: tuple, new_id: int) -> List[int]:
        """Replace all occurrences of pair with new_id."""
        result = []
        i = 0
        while i < len(token_ids):
            # Merge non-overlapping occurrences, scanning left to right.
            if i < len(token_ids) - 1 and (token_ids[i], token_ids[i + 1]) == pair:
                result.append(new_id)
                i += 2
            else:
                result.append(token_ids[i])
                i += 1
        return result
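
    # Illustrative example (assumed IDs): _merge_pair([1, 2, 1, 2, 3], (1, 2), 7)
    # returns [7, 7, 3].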
    def _apply_bpe(self, token_str: str) -> List[int]:
        """Apply BPE merges to a string token."""
        # Characters never seen during training have no ID and are
        # silently dropped.
        token_ids = [self.inverse_vocab[char] for char in token_str
                     if char in self.inverse_vocab]
        if len(token_ids) <= 1:
            return token_ids
        # Repeatedly apply the earliest-learned (lowest-rank) merge until
        # no adjacent pair has a recorded merge rule.
        while len(token_ids) > 1:
            min_rank = float('inf')
            min_pos = -1
            for i in range(len(token_ids) - 1):
                pair = (token_ids[i], token_ids[i + 1])
                rank = self.bpe_ranks.get(pair)
                if rank is not None and rank < min_rank:
                    min_rank = rank
                    min_pos = i
            if min_pos == -1:
                break
            pair = (token_ids[min_pos], token_ids[min_pos + 1])
            merged_token_str = self.vocab[pair[0]] + self.vocab[pair[1]]
            new_id = self.inverse_vocab[merged_token_str]
            token_ids = token_ids[:min_pos] + [new_id] + token_ids[min_pos + 2:]
        return token_ids
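
    # Illustrative walk-through (hypothetical merges, not from the project):
    # if training learned the merge for ('न', 'म') at rank 0 and for
    # ('नम', 'स') at rank 1, then _apply_bpe("नमस") proceeds
    # [न, म, स] -> [नम, स] -> [नमस], always taking the lowest-rank
    # (earliest-learned) available merge first.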
    def encode(self, text: str) -> List[int]:
        """Encode text into token IDs."""
        chunks = self.pattern.findall(text)
        token_ids = []
        for chunk in chunks:
            token_ids.extend(self._apply_bpe(chunk))
        return token_ids
    def decode(self, token_ids: List[int]) -> str:
        """Convert token IDs back to text."""
        # Unknown IDs are skipped rather than raising, mirroring how
        # _apply_bpe drops characters that were never seen in training.
        result = []
        for token_id in token_ids:
            if token_id in self.vocab:
                result.append(self.vocab[token_id])
        return "".join(result)
    def save_vocab(self, filepath: str):
        """Save vocabulary and merge rules to JSON file."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump({
                'vocab': {str(k): v for k, v in self.vocab.items()},
                'bpe_merges': [[p[0], p[1]] for p in self.bpe_merges]
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved vocabulary to {filepath}")
    def load_vocab(self, filepath: str):
        """Load vocabulary and merge rules from JSON file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.vocab = {int(k): v for k, v in data['vocab'].items()}
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.bpe_merges = [tuple(pair) for pair in data['bpe_merges']]
        self.bpe_ranks = {tuple(pair): idx for idx, pair in enumerate(self.bpe_merges)}
        print(f"Loaded vocabulary from {filepath} (size: {len(self.vocab)})")