| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- import os
- import pickle
- import regex as re
- import numpy as np
- from tqdm import tqdm
- from typing import Iterable
- from concurrent.futures import ProcessPoolExecutor, as_completed
- from .utils import to_bytes_tuple, PAT
- class Tokenizer:
- def __init__(
- self,
- vocab: dict[int, bytes],
- merges: list[tuple[bytes, bytes]],
- special_tokens: list[str] | None = None,
- ):
- """
- Construct a BPE tokenizer from a given vocabulary, list of merges, and (optionally) special tokens.
-
- Args:
- vocab: A dictionary mapping token IDs to their byte representations.
- merges: A list of tuples representing BPE merge operations.
- special_tokens: Optional list of strings that should be treated as unbreakable tokens.
- """
- self.vocab = vocab
- self.byte_to_token_id = {v: k for k, v in vocab.items()}
- self.merges = merges
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
- assert '<|endoftext|>' in special_tokens, "<|endoftext|> must be in special_tokens"
- assert special_tokens[0] == '<|endoftext|>', "<|endoftext|> must be the first token in special_tokens"
- self.eos_token_id = 256
- self.vocab_size = len(vocab)
- # Handle special tokens
- self.special_tokens = special_tokens
- self.special_token_bytes = [token.encode("utf-8") for token in self.special_tokens]
-
- # Ensure special tokens are in the vocabulary
- for token_bytes in self.special_token_bytes:
- if token_bytes not in self.byte_to_token_id:
- # Add to vocab if not already present
- new_id = len(self.vocab)
- self.vocab[new_id] = token_bytes
- self.byte_to_token_id[token_bytes] = new_id
- def encode(self, text: str) -> list[int]:
- """
- Encode an input text string into a sequence of token IDs.
-
- Args:
- text: The input text to encode.
-
- Returns:
- A list of integer token IDs representing the encoded text.
- """
- tokens = []
- # Sort special tokens by length (longest first) to avoid partial matches
- sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)
- pattern = "|".join(map(re.escape, sorted_special_tokens))
- if pattern:
- parts = re.split(f"({pattern})", text)
- else:
- parts = [text]
- for part in parts:
- if part in self.special_tokens:
- # If it's a special token, add its ID directly
- tokens.append(self.byte_to_token_id[part.encode("utf-8")])
- else:
- # Otherwise, tokenize normally using BPE
- tokens.extend(self._tokenize_normal(part))
- return tokens
- def encode_iterable(self, iterable: Iterable[str]) -> iter:
- """
- Given an iterable of strings (e.g., a file handle), yield token IDs lazily.
-
- Args:
- iterable: An iterable source of text chunks.
-
- Yields:
- Token IDs generated by processing the input iterable.
- """
- for chunk in iterable:
- yield from self.encode(chunk)
- def decode(self, ids: list[int]) -> str:
- """
- Decode a sequence of token IDs back into a human-readable string.
-
- Args:
- ids: A list of integer token IDs.
-
- Returns:
- The decoded string representation of the input token IDs.
- """
- # Concatenate all token bytes
- full_bytes = b"".join(self.vocab[token_id] for token_id in ids)
-
- # Decode bytes to string, replacing invalid sequences
- return full_bytes.decode("utf-8", errors="replace")
- def _tokenize_normal(self, text: str) -> list[int]:
- """
- Tokenize a normal piece of text (not a special token) into token IDs.
-
- Args:
- text: A string to tokenize.
-
- Returns:
- A list of token IDs representing the tokenized text.
- """
- # Pre-tokenization
- pre_tokens = []
- for m in re.finditer(PAT, text):
- word = m.group(0)
- pre_tokens.append(word)
- token_ids = []
- for token in pre_tokens:
- # Convert token to bytes tuple
- byte_tuple = to_bytes_tuple(token)
-
- # Apply BPE merges
- merged = self._apply_merges(byte_tuple)
-
- # Get token IDs
- token_ids.extend(self.byte_to_token_id[b] for b in merged)
-
- return token_ids
- def _apply_merges(self, byte_tuple: tuple[bytes, ...]) -> list[bytes]:
- """
- Apply BPE merges to a sequence of bytes.
-
- Args:
- byte_tuple: A tuple of single-byte tokens.
-
- Returns:
- A list of merged byte tokens after applying all applicable merges.
- """
- word: list[bytes] = list(byte_tuple)
- def get_pairs(word: list[bytes]):
- pairs = set()
- prev_char = word[0]
- for char in word[1:]:
- pairs.add((prev_char, char))
- prev_char = char
- return pairs
-
- pairs = get_pairs(word)
- if not pairs:
- return word
- while True:
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
- if bigram not in self.bpe_ranks:
- break
-
- first, second = bigram
- new_word = []
- i = 0
- while i < len(word):
- try:
- j = word.index(first, i)
- except ValueError:
- new_word.extend(word[i:])
- break
- else:
- new_word.extend(word[i:j])
- i = j
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
- new_word.append(first + second)
- i += 2
- else:
- new_word.append(word[i])
- i += 1
- new_word = tuple(new_word)
- word = new_word
- if len(word) == 1:
- break
- else:
- pairs = get_pairs(word)
- return word
def get_custom_tokenizer(vocab_path,
                         merges_path,
                         special_tokens: list[str] | None = None):
    """Load a pickled vocabulary and merge list from disk and build a Tokenizer.

    Args:
        vocab_path: Path to a pickle holding the ``vocab`` argument for Tokenizer
            (``dict[int, bytes]``).
        merges_path: Path to a pickle holding the ``merges`` argument for Tokenizer
            (``list[tuple[bytes, bytes]]``).
        special_tokens: Special tokens to register. Defaults to ``["<|endoftext|>"]``
            when None. A fresh list is created per call so it is never shared.

    Returns:
        The constructed Tokenizer.
    """
    # None sentinel instead of a list default: a default list would be shared by
    # every call and by every Tokenizer built from it.
    if special_tokens is None:
        special_tokens = ["<|endoftext|>"]
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    with open(merges_path, 'rb') as f:
        merges = pickle.load(f)
    return Tokenizer(
        vocab=vocab,
        merges=merges,
        special_tokens=special_tokens,
    )
def encode_txt_as_array_slow(tokenizer, path_to_txt, save_path):
    """Tokenize a text file line by line in one process and save the token IDs.

    The output is a flat, headerless array of native-endian int32 token IDs, in
    input order. It can be read back with
    ``np.memmap(save_path, dtype=np.int32, mode="r")``.

    Args:
        tokenizer: Object with an ``encode(str) -> list[int]`` method (e.g. Tokenizer).
        path_to_txt: Path of the UTF-8 text file to tokenize.
        save_path: Destination of the binary token file. Missing parent
            directories are created.
    """
    # Cheap pass used only to give the progress bar a total.
    with open(path_to_txt, 'r', encoding='utf-8') as f:
        num_lines = sum(1 for _ in f)

    out_dir = os.path.dirname(save_path)
    if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)

    # Encode each line exactly once and append its raw int32 bytes. This needs no
    # up-front token count, keeps memory bounded by a single line, and writes a
    # valid empty file when there are no tokens (an empty memmap cannot be created).
    with open(path_to_txt, 'r', encoding='utf-8') as src, open(save_path, 'wb') as dst:
        for line in tqdm(src, total=num_lines, desc="Tokenizing"):
            ids = np.asarray(tokenizer.encode(line), dtype=np.int32)
            dst.write(ids.tobytes())
def batch_tokenize(batch, tokenizer):
    """Encode every line of ``batch`` and return all IDs concatenated as int32.

    Kept at module level because encode_txt_as_array submits it to a
    ProcessPoolExecutor, which must pickle the callable by reference.

    Args:
        batch: Sequence of text lines.
        tokenizer: Object with an ``encode(str) -> list[int]`` method.

    Returns:
        1-D ``np.int32`` array of the lines' token IDs, in order.
    """
    flat_ids = [token_id for line in batch for token_id in tokenizer.encode(line)]
    return np.asarray(flat_ids, dtype=np.int32)
def encode_txt_as_array(tokenizer, path_to_txt, save_path, batch_size=4096, n_workers=8):
    """Tokenize a text file in parallel and save the token IDs as a raw int32 array.

    Lines are grouped into batches of ``batch_size``, and each batch is encoded by
    ``batch_tokenize`` in a process pool. The output is a flat, headerless array of
    native-endian int32 token IDs in the same order as the input lines. It can be
    read back with ``np.memmap(save_path, dtype=np.int32, mode="r")``.

    Args:
        tokenizer: Object with an ``encode(str) -> list[int]``; it must be picklable.
        path_to_txt: Path of the UTF-8 text file to tokenize.
        save_path: Destination of the binary token file. Missing parent
            directories are created.
        batch_size: Number of lines per task sent to a worker process.
        n_workers: Maximum number of worker processes.
    """
    # 1. Split the input into batches of lines (newlines are kept).
    batches: list[list[str]] = []
    batch: list[str] = []
    with open(path_to_txt, encoding='utf-8') as f:
        for line in f:
            batch.append(line)
            if len(batch) == batch_size:
                batches.append(batch)
                batch = []
    if batch:
        batches.append(batch)

    # 2. Tokenize in parallel. Results are collected in SUBMISSION order:
    # as_completed() yields in completion order, which would write batches out of
    # order and scramble the corpus.
    # NOTE: the tokenizer is pickled with every submitted batch.
    with ProcessPoolExecutor(max_workers=n_workers) as exe:
        futures = [exe.submit(batch_tokenize, b, tokenizer) for b in batches]
        results = [fut.result() for fut in tqdm(futures, desc="Tokenizing")]

    # 3. Write the raw int32 bytes (the same layout an np.memmap would produce).
    # This also yields a valid empty file when there are no tokens.
    out_dir = os.path.dirname(save_path)
    if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    with open(save_path, 'wb') as out:
        for arr in results:
            out.write(arr.tobytes())
|