import os
import pickle
import regex as re
import numpy as np
from tqdm import tqdm
from typing import Iterable, Iterator
from concurrent.futures import ProcessPoolExecutor, as_completed
from .utils import to_bytes_tuple, PAT


_EOT = "<|endoftext|>"


class Tokenizer:
    """Byte-level BPE tokenizer with support for unbreakable special tokens."""

    def __init__(
        self,
        vocab: dict[int, bytes],
        merges: list[tuple[bytes, bytes]],
        special_tokens: list[str] | None = None,
    ):
        """
        Construct a BPE tokenizer from a given vocabulary, list of merges, and (optionally) special tokens.

        Args:
            vocab: A dictionary mapping token IDs to their byte representations.
                It is copied, so registering special tokens never modifies the caller's dict.
            merges: A list of tuples representing BPE merge operations, in priority order
                (an earlier merge is applied before a later one).
            special_tokens: Optional list of strings that should be treated as unbreakable tokens.
                "<|endoftext|>" must be present and must be the first entry.
                None defaults to ["<|endoftext|>"].

        Raises:
            ValueError: If "<|endoftext|>" is missing from special_tokens or is not its first entry.
        """
        if special_tokens is None:
            special_tokens = [_EOT]
        # Explicit raises instead of assert: asserts are stripped under `python -O`.
        if _EOT not in special_tokens:
            raise ValueError("<|endoftext|> must be in special_tokens")
        if special_tokens[0] != _EOT:
            raise ValueError("<|endoftext|> must be the first token in special_tokens")

        # Copy so the special-token registration below never mutates the caller's dict.
        self.vocab = dict(vocab)
        self.byte_to_token_id = {v: k for k, v in self.vocab.items()}
        self.merges = merges
        # Lower rank = learned earlier = applied first in _apply_merges.
        self.bpe_ranks = {pair: rank for rank, pair in enumerate(merges)}

        # Handle special tokens
        self.special_tokens = list(special_tokens)
        self.special_token_bytes = [token.encode("utf-8") for token in self.special_tokens]

        # Ensure special tokens are in the vocabulary. New IDs go past the largest existing ID
        # (not len(vocab)) so a non-contiguous vocab can never have an entry overwritten.
        for token_bytes in self.special_token_bytes:
            if token_bytes not in self.byte_to_token_id:
                new_id = max(self.vocab, default=-1) + 1
                self.vocab[new_id] = token_bytes
                self.byte_to_token_id[token_bytes] = new_id

        # Look the ID up rather than assuming 256, which only holds when the vocab places
        # <|endoftext|> immediately after the 256 single-byte tokens.
        self.eos_token_id = self.byte_to_token_id[_EOT.encode("utf-8")]
        # Computed after registration so any newly added special tokens are counted.
        self.vocab_size = len(self.vocab)

        # Built once instead of on every encode() call. Longest tokens come first so a
        # special token that contains another one wins the alternation; the capture group
        # makes re.split keep the matched special tokens in its output.
        sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)
        alternation = "|".join(map(re.escape, sorted_special_tokens))
        self._special_split_pattern = f"({alternation})" if alternation else None

    def encode(self, text: str) -> list[int]:
        """
        Encode an input text string into a sequence of token IDs.

        Args:
            text: The input text to encode.

        Returns:
            A list of integer token IDs representing the encoded text.
        """
        if self._special_split_pattern is None:
            parts = [text]
        else:
            parts = re.split(self._special_split_pattern, text)

        tokens: list[int] = []
        for part in parts:
            if not part:
                # re.split yields "" before/after/between adjacent special tokens.
                continue
            if part in self.special_tokens:
                # A special token maps directly to its own ID and is never BPE-merged.
                tokens.append(self.byte_to_token_id[part.encode("utf-8")])
            else:
                # Otherwise, tokenize normally using BPE
                tokens.extend(self._tokenize_normal(part))
        return tokens

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        """
        Given an iterable of strings (e.g., a file handle), yield token IDs lazily.

        Each chunk is encoded independently, so BPE merges and special-token matches never
        span a chunk boundary.

        Args:
            iterable: An iterable source of text chunks.

        Yields:
            Token IDs generated by processing the input iterable.
        """
        for chunk in iterable:
            yield from self.encode(chunk)

    def decode(self, ids: list[int]) -> str:
        """
        Decode a sequence of token IDs back into a human-readable string.

        Args:
            ids: A list of integer token IDs.

        Returns:
            The decoded string. Byte sequences that are not valid UTF-8 become U+FFFD.

        Raises:
            KeyError: If an ID is not in the vocabulary.
        """
        # Concatenate first: a multi-byte character may be split across several tokens.
        full_bytes = b"".join(self.vocab[token_id] for token_id in ids)
        return full_bytes.decode("utf-8", errors="replace")

    def _tokenize_normal(self, text: str) -> list[int]:
        """
        Tokenize a piece of text that contains no special tokens into token IDs.

        Args:
            text: A string to tokenize.

        Returns:
            A list of token IDs representing the tokenized text.
        """
        token_ids: list[int] = []
        # Pre-tokenize with PAT; merges are then applied within each pre-token only.
        for match in re.finditer(PAT, text):
            # NOTE(review): assumes to_bytes_tuple splits the pre-token into single-byte
            # `bytes` objects — confirm against .utils.
            merged = self._apply_merges(to_bytes_tuple(match.group(0)))
            token_ids.extend(self.byte_to_token_id[b] for b in merged)
        return token_ids

    def _apply_merges(self, byte_tuple: tuple[bytes, ...]) -> list[bytes]:
        """
        Apply BPE merges to a sequence of bytes.

        Repeatedly merges every occurrence of the adjacent pair with the lowest rank in
        self.bpe_ranks until no adjacent pair has a rank.

        Args:
            byte_tuple: A tuple of single-byte tokens.

        Returns:
            A list of merged byte tokens after applying all applicable merges.
        """
        word: list[bytes] = list(byte_tuple)
        ranks = self.bpe_ranks
        unranked = float("inf")

        while len(word) > 1:
            pairs = set(zip(word, word[1:]))
            bigram = min(pairs, key=lambda pair: ranks.get(pair, unranked))
            if bigram not in ranks:
                break  # no mergeable pair left
            first, second = bigram

            # Left-to-right scan merging every non-overlapping occurrence of `bigram`.
            new_word: list[bytes] = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j
                if i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(first)
                    i += 1
            word = new_word

        return word


def get_custom_tokenizer(vocab_path, merges_path, special_tokens: list[str] | None = None):
    """
    Load a pickled vocab and merge list from disk and build a Tokenizer.

    Args:
        vocab_path: Path to a pickle holding the ID -> bytes vocabulary dict.
        merges_path: Path to a pickle holding the list of merge pairs.
        special_tokens: Special tokens to register; None means ["<|endoftext|>"].

    Returns:
        The constructed Tokenizer.
    """
    # None sentinel instead of a shared mutable list default.
    if special_tokens is None:
        special_tokens = [_EOT]
    # NOTE(security): pickle.load can execute arbitrary code embedded in the file; only
    # load vocab/merges files that you produced or otherwise trust.
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    with open(merges_path, "rb") as f:
        merges = pickle.load(f)
    return Tokenizer(vocab=vocab, merges=merges, special_tokens=special_tokens)


def _open_token_memmap(save_path, total_tokens):
    """Create save_path's parent dir if needed and return a writable int32 memmap of total_tokens entries (None, after writing an empty file, when total_tokens is 0)."""
    parent = os.path.dirname(save_path)
    # os.makedirs("") raises, so skip it when save_path is a bare filename.
    if parent:
        os.makedirs(parent, exist_ok=True)
    if total_tokens == 0:
        # np.memmap cannot map a zero-length file; leave an empty output file instead.
        with open(save_path, "wb"):
            pass
        return None
    return np.memmap(save_path, dtype=np.int32, mode="w+", shape=(total_tokens,))


def encode_txt_as_array_slow(tokenizer, path_to_txt, save_path):
    """
    Tokenize a UTF-8 text file line by line into a flat int32 memmap at save_path.

    Single-process and memory-bounded: the file is read three times (line count, token
    count, write), and every line is encoded twice.

    Args:
        tokenizer: Object with an encode(str) -> list[int] method.
        path_to_txt: Path to the input text file.
        save_path: Output path for the raw int32 token array.
    """
    with open(path_to_txt, "r", encoding="utf-8") as f:
        num_lines = sum(1 for _ in f)

    # Step 1: count the total number of tokens (requires one full pass).
    total_tokens = 0
    with open(path_to_txt, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=num_lines, desc="Counting tokens"):
            total_tokens += len(tokenizer.encode(line))

    # Step 2: create the memmap.
    tokens_mm = _open_token_memmap(save_path, total_tokens)
    if tokens_mm is None:
        return

    # Step 3: traverse again and write the tokens.
    pos = 0
    with open(path_to_txt, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=num_lines, desc="Tokenizing"):
            ids = tokenizer.encode(line)
            n = len(ids)
            tokens_mm[pos:pos + n] = ids
            pos += n
    tokens_mm.flush()


def batch_tokenize(batch, tokenizer):
    """Encode each string in batch and return all IDs concatenated as an int32 array."""
    out: list[int] = []
    for line in batch:
        out.extend(tokenizer.encode(line))
    return np.array(out, dtype=np.int32)


# Tokenizer installed once per worker process by _init_worker, so it is pickled once per
# worker rather than once per submitted batch.
_WORKER_TOKENIZER = None


def _init_worker(tokenizer):
    """Pool initializer: store the tokenizer in this worker process's module globals."""
    global _WORKER_TOKENIZER
    _WORKER_TOKENIZER = tokenizer


def _batch_tokenize_in_worker(batch):
    """Tokenize one batch with the tokenizer installed by _init_worker."""
    return batch_tokenize(batch, _WORKER_TOKENIZER)


def encode_txt_as_array(tokenizer, path_to_txt, save_path, batch_size=4096, n_workers=8):
    """
    Tokenize a UTF-8 text file in parallel and write the IDs as a flat int32 memmap.

    Lines are grouped into batches of batch_size, tokenized in up to n_workers processes,
    and written to save_path in the original line order.

    Args:
        tokenizer: Picklable object with an encode(str) -> list[int] method.
        path_to_txt: Path to the input text file.
        save_path: Output path for the raw int32 token array.
        batch_size: Number of lines per worker task.
        n_workers: Maximum number of worker processes.
    """
    # 1. Split the lines into batches.
    batches: list[list[str]] = []
    with open(path_to_txt, encoding="utf-8") as f:
        batch: list[str] = []
        for line in f:
            batch.append(line)
            if len(batch) == batch_size:
                batches.append(batch)
                batch = []
        if batch:
            batches.append(batch)

    # 2. Tokenize with multiple processes. as_completed yields in *completion* order, so each
    # result is stored at its batch index to keep the output in file order.
    results = [None] * len(batches)
    with ProcessPoolExecutor(
        max_workers=n_workers, initializer=_init_worker, initargs=(tokenizer,)
    ) as exe:
        future_to_index = {
            exe.submit(_batch_tokenize_in_worker, b): idx for idx, b in enumerate(batches)
        }
        for fut in tqdm(as_completed(future_to_index), total=len(future_to_index), desc="Tokenizing"):
            results[future_to_index[fut]] = fut.result()

    total_tokens = sum(arr.shape[0] for arr in results)

    # 3. Write the memmap.
    tokens_mm = _open_token_memmap(save_path, total_tokens)
    if tokens_mm is None:
        return
    pos = 0
    for arr in results:
        n = arr.shape[0]
        tokens_mm[pos:pos + n] = arr
        pos += n
    tokens_mm.flush()