from __future__ import annotations import os import regex as re from tqdm import tqdm from collections import defaultdict from typing import Dict from .utils import to_bytes_tuple, PAT def run_train_bpe( input_path: str | os.PathLike, vocab_size: int, special_tokens: list[str], **kwargs, ) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]: """Given the path to an input corpus, run train a BPE tokenizer and output its vocabulary and merges. Args: input_path (str | os.PathLike): Path to BPE tokenizer training data. vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens). special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary. These strings will never be split into multiple tokens, and will always be kept as a single token. If these special tokens occur in the `input_path`, they are treated as any other string. Returns: tuple[dict[int, bytes], list[tuple[bytes, bytes]]]: vocab: The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes) merges: BPE merges. Each list item is a tuple of bytes (, ), representing that was merged with . Merges are ordered by order of creation. """ # Step 1: Initialize Vocabulary vocab: Dict[int, bytes] = {i: bytes([i]) for i in range(256)} next_id = 256 special_token_bytes = [token.encode("utf-8") for token in special_tokens] for token_bytes in special_token_bytes: if token_bytes not in vocab.values(): vocab[next_id] = token_bytes next_id += 1 # Step 2: Pre-tokenization pre_tokens_cnt = defaultdict(int) with open(input_path, "r", encoding="utf-8") as f: text = f.read() chunks = re.split("|".join(map(re.escape, special_tokens)), text) for chunk in tqdm(chunks, desc="Pretokenize"): for m in re.finditer(PAT, chunk): word = m.group(0) pre_tokens_cnt[to_bytes_tuple(word)] += 1 # key of pre_tokens_cnt e.g. (b'H', b'e', b'l', b'l', b'o') # Step 3: Compute BPE Merges merges = [] init_vocab_size = len(vocab) num_merges = vocab_size - init_vocab_size # tqdm接管循环 for _ in tqdm(range(num_merges), desc="BPE merges"): pair_counts = defaultdict(int) # Count all adjacent byte pairs for token, cnt in pre_tokens_cnt.items(): # token e.g. "hello" for i in range(len(token) - 1): pair = (token[i], token[i + 1]) pair_counts[pair] += cnt if not pair_counts: break # No more pairs to merge # Find the most frequent pair(s) max_count = max(pair_counts.values()) candidates = [k for k, v in pair_counts.items() if v == max_count] best_pair = max(candidates) a, b = best_pair # Create new token new_token = a + b vocab[next_id] = new_token next_id += 1 # Apply the merge to all pre-tokenized sequences # 收集变更 changes = [] for token, cnt in pre_tokens_cnt.items(): # Find all occurrences of the `best_pair` in `token` indices = [i for i in range(len(token) - 1) if token[i:i + 2] == best_pair] if indices: # Replace each occurrence with `new_token` new_pre_token = [] i = 0 while i < len(token): if i in indices: new_pre_token.append(new_token) i += 2 else: new_pre_token.append(token[i]) i += 1 new_pre_token = tuple(new_pre_token) changes.append((token, new_pre_token, cnt)) # 应用变更 for old_token, new_pre_token, cnt in changes: pre_tokens_cnt[new_pre_token] = pre_tokens_cnt.get(new_pre_token, 0) + cnt del pre_tokens_cnt[old_token] # Record the merge merges.append((a, b)) return vocab, merges