| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- from __future__ import annotations
- import os
- import regex as re
- from tqdm import tqdm
- from collections import defaultdict
- from typing import Dict
- from .utils import to_bytes_tuple, PAT
def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE tokenizer on the corpus at ``input_path``.

    The vocabulary starts with the 256 single-byte tokens followed by the
    special tokens. The corpus is split on the special tokens, each piece is
    pre-tokenized with ``PAT``, and the most frequent adjacent pair is merged
    repeatedly until the vocabulary reaches ``vocab_size`` or no pairs remain.

    Args:
        input_path: Path to the UTF-8 training text. The whole file is read
            into memory.
        vocab_size: Target total vocabulary size, including the 256 byte
            tokens and the special tokens. If it does not exceed the initial
            vocabulary size, no merges are performed.
        special_tokens: Strings added to the vocabulary as single tokens.
            Occurrences in the corpus act as hard boundaries: the text is
            split on them before pre-tokenization, so they contribute no pair
            counts and no merge spans across them.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A ``(vocab, merges)`` tuple.

        vocab: Mapping from token ID to token bytes.
        merges: ``(token1, token2)`` byte pairs in order of creation; each
            entry records that ``token1`` was merged with ``token2``.
    """
    # Step 1: Initialize the vocabulary (single bytes, then special tokens).
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    known_tokens = set(vocab.values())  # O(1) duplicate check
    next_id = 256
    for special in special_tokens:
        special_bytes = special.encode("utf-8")
        if special_bytes not in known_tokens:
            vocab[next_id] = special_bytes
            known_tokens.add(special_bytes)
            next_id += 1

    # Step 2: Pre-tokenization.
    pre_tokens_cnt = defaultdict(int)
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()

    # Longest-first so a special token that is a prefix of another cannot
    # shadow the longer one in the alternation. Empty strings are dropped:
    # an empty pattern matches at every position and would split the text
    # into single characters, leaving no adjacent pairs to merge.
    delimiters = sorted((tok for tok in special_tokens if tok), key=len, reverse=True)
    if delimiters:
        chunks = re.split("|".join(map(re.escape, delimiters)), text)
    else:
        chunks = [text]

    for chunk in tqdm(chunks, desc="Pretokenize"):
        for m in re.finditer(PAT, chunk):
            # Key is the pre-token as a tuple of single-byte tokens,
            # e.g. (b'H', b'e', b'l', b'l', b'o').
            pre_tokens_cnt[to_bytes_tuple(m.group(0))] += 1

    # Step 3: Compute BPE merges.
    merges: list[tuple[bytes, bytes]] = []
    num_merges = vocab_size - len(vocab)  # a negative value yields an empty range
    for _ in tqdm(range(num_merges), desc="BPE merges"):
        # Count every adjacent pair, weighted by pre-token frequency.
        pair_counts = defaultdict(int)
        for token, cnt in pre_tokens_cnt.items():
            for pair in zip(token, token[1:]):
                pair_counts[pair] += cnt
        if not pair_counts:
            break  # No more pairs to merge

        # Most frequent pair; ties go to the lexicographically greatest pair.
        best_pair = max(pair_counts, key=lambda p: (pair_counts[p], p))
        a, b = best_pair

        # Create the new token.
        new_token = a + b
        vocab[next_id] = new_token
        next_id += 1

        # Collect the rewrites first; the dict must not change size while it
        # is being iterated.
        changes = []
        for token, cnt in pre_tokens_cnt.items():
            merged = []
            last = len(token) - 1
            i = 0
            # Greedy left-to-right scan: overlapping occurrences such as
            # (a, a, a) with pair (a, a) merge only the leftmost match.
            while i <= last:
                if i < last and token[i] == a and token[i + 1] == b:
                    merged.append(new_token)
                    i += 2
                else:
                    merged.append(token[i])
                    i += 1
            if len(merged) != len(token):
                changes.append((token, tuple(merged), cnt))

        # Apply the rewrites: drop every old key before adding the new ones
        # so a freshly written key can never be removed by a later delete.
        for old_token, _, _ in changes:
            del pre_tokens_cnt[old_token]
        for _, new_pre_token, cnt in changes:
            pre_tokens_cnt[new_pre_token] += cnt

        # Record the merge.
        merges.append((a, b))

    return vocab, merges
|