tokenizer.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import os
  2. import pickle
  3. import regex as re
  4. import numpy as np
  5. from tqdm import tqdm
  6. from typing import Iterable
  7. from concurrent.futures import ProcessPoolExecutor, as_completed
  8. from .utils import to_bytes_tuple, PAT
  9. class Tokenizer:
  10. def __init__(
  11. self,
  12. vocab: dict[int, bytes],
  13. merges: list[tuple[bytes, bytes]],
  14. special_tokens: list[str] | None = None,
  15. ):
  16. """
  17. Construct a BPE tokenizer from a given vocabulary, list of merges, and (optionally) special tokens.
  18. Args:
  19. vocab: A dictionary mapping token IDs to their byte representations.
  20. merges: A list of tuples representing BPE merge operations.
  21. special_tokens: Optional list of strings that should be treated as unbreakable tokens.
  22. """
  23. self.vocab = vocab
  24. self.byte_to_token_id = {v: k for k, v in vocab.items()}
  25. self.merges = merges
  26. self.bpe_ranks = dict(zip(merges, range(len(merges))))
  27. assert '<|endoftext|>' in special_tokens, "<|endoftext|> must be in special_tokens"
  28. assert special_tokens[0] == '<|endoftext|>', "<|endoftext|> must be the first token in special_tokens"
  29. self.eos_token_id = 256
  30. self.vocab_size = len(vocab)
  31. # Handle special tokens
  32. self.special_tokens = special_tokens
  33. self.special_token_bytes = [token.encode("utf-8") for token in self.special_tokens]
  34. # Ensure special tokens are in the vocabulary
  35. for token_bytes in self.special_token_bytes:
  36. if token_bytes not in self.byte_to_token_id:
  37. # Add to vocab if not already present
  38. new_id = len(self.vocab)
  39. self.vocab[new_id] = token_bytes
  40. self.byte_to_token_id[token_bytes] = new_id
  41. def encode(self, text: str) -> list[int]:
  42. """
  43. Encode an input text string into a sequence of token IDs.
  44. Args:
  45. text: The input text to encode.
  46. Returns:
  47. A list of integer token IDs representing the encoded text.
  48. """
  49. tokens = []
  50. # Sort special tokens by length (longest first) to avoid partial matches
  51. sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)
  52. pattern = "|".join(map(re.escape, sorted_special_tokens))
  53. if pattern:
  54. parts = re.split(f"({pattern})", text)
  55. else:
  56. parts = [text]
  57. for part in parts:
  58. if part in self.special_tokens:
  59. # If it's a special token, add its ID directly
  60. tokens.append(self.byte_to_token_id[part.encode("utf-8")])
  61. else:
  62. # Otherwise, tokenize normally using BPE
  63. tokens.extend(self._tokenize_normal(part))
  64. return tokens
  65. def encode_iterable(self, iterable: Iterable[str]) -> iter:
  66. """
  67. Given an iterable of strings (e.g., a file handle), yield token IDs lazily.
  68. Args:
  69. iterable: An iterable source of text chunks.
  70. Yields:
  71. Token IDs generated by processing the input iterable.
  72. """
  73. for chunk in iterable:
  74. yield from self.encode(chunk)
  75. def decode(self, ids: list[int]) -> str:
  76. """
  77. Decode a sequence of token IDs back into a human-readable string.
  78. Args:
  79. ids: A list of integer token IDs.
  80. Returns:
  81. The decoded string representation of the input token IDs.
  82. """
  83. # Concatenate all token bytes
  84. full_bytes = b"".join(self.vocab[token_id] for token_id in ids)
  85. # Decode bytes to string, replacing invalid sequences
  86. return full_bytes.decode("utf-8", errors="replace")
  87. def _tokenize_normal(self, text: str) -> list[int]:
  88. """
  89. Tokenize a normal piece of text (not a special token) into token IDs.
  90. Args:
  91. text: A string to tokenize.
  92. Returns:
  93. A list of token IDs representing the tokenized text.
  94. """
  95. # Pre-tokenization
  96. pre_tokens = []
  97. for m in re.finditer(PAT, text):
  98. word = m.group(0)
  99. pre_tokens.append(word)
  100. token_ids = []
  101. for token in pre_tokens:
  102. # Convert token to bytes tuple
  103. byte_tuple = to_bytes_tuple(token)
  104. # Apply BPE merges
  105. merged = self._apply_merges(byte_tuple)
  106. # Get token IDs
  107. token_ids.extend(self.byte_to_token_id[b] for b in merged)
  108. return token_ids
  109. def _apply_merges(self, byte_tuple: tuple[bytes, ...]) -> list[bytes]:
  110. """
  111. Apply BPE merges to a sequence of bytes.
  112. Args:
  113. byte_tuple: A tuple of single-byte tokens.
  114. Returns:
  115. A list of merged byte tokens after applying all applicable merges.
  116. """
  117. word: list[bytes] = list(byte_tuple)
  118. def get_pairs(word: list[bytes]):
  119. pairs = set()
  120. prev_char = word[0]
  121. for char in word[1:]:
  122. pairs.add((prev_char, char))
  123. prev_char = char
  124. return pairs
  125. pairs = get_pairs(word)
  126. if not pairs:
  127. return word
  128. while True:
  129. bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
  130. if bigram not in self.bpe_ranks:
  131. break
  132. first, second = bigram
  133. new_word = []
  134. i = 0
  135. while i < len(word):
  136. try:
  137. j = word.index(first, i)
  138. except ValueError:
  139. new_word.extend(word[i:])
  140. break
  141. else:
  142. new_word.extend(word[i:j])
  143. i = j
  144. if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
  145. new_word.append(first + second)
  146. i += 2
  147. else:
  148. new_word.append(word[i])
  149. i += 1
  150. new_word = tuple(new_word)
  151. word = new_word
  152. if len(word) == 1:
  153. break
  154. else:
  155. pairs = get_pairs(word)
  156. return word
  157. def get_custom_tokenizer(vocab_path,
  158. merges_path,
  159. special_tokens: list[str] = ["<|endoftext|>"]):
  160. with open(vocab_path, 'rb') as f:
  161. vocab = pickle.load(f)
  162. with open(merges_path, 'rb') as f:
  163. merges = pickle.load(f)
  164. tokenizer = Tokenizer(
  165. vocab=vocab,
  166. merges=merges,
  167. special_tokens=special_tokens
  168. )
  169. return tokenizer
  170. def encode_txt_as_array_slow(tokenizer, path_to_txt, save_path):
  171. with open(path_to_txt, 'r') as f:
  172. num_lines = sum(1 for _ in f)
  173. # 第一步:统计总token数(需要遍历一遍)
  174. total_tokens = 0
  175. with open(path_to_txt, 'r') as f:
  176. for line in tqdm(f, total=num_lines, desc="Counting tokens"):
  177. total_tokens += len(tokenizer.encode(line))
  178. # 第二步:创建memmap
  179. dtype = np.int32
  180. tokens_mm = np.memmap(save_path, dtype=dtype, mode='w+', shape=(total_tokens,))
  181. # 第三步:再次遍历写入
  182. pos = 0
  183. with open(path_to_txt, 'r') as f:
  184. for line in tqdm(f, total=num_lines, desc="Tokenizing"):
  185. ids = tokenizer.encode(line)
  186. n = len(ids)
  187. tokens_mm[pos:pos+n] = ids
  188. pos += n
  189. tokens_mm.flush()
  190. def batch_tokenize(batch, tokenizer):
  191. out = []
  192. for line in batch:
  193. out.extend(tokenizer.encode(line))
  194. return np.array(out, dtype=np.int32)
  195. def encode_txt_as_array(tokenizer, path_to_txt, save_path, batch_size=4096, n_workers=8):
  196. # 1.分batch
  197. batches = []
  198. with open(path_to_txt) as f:
  199. batch = []
  200. for line in f:
  201. batch.append(line)
  202. if len(batch) == batch_size:
  203. batches.append(batch)
  204. batch = []
  205. if batch:
  206. batches.append(batch)
  207. total_tokens = 0
  208. results = []
  209. # 2.多进程tokenize
  210. with ProcessPoolExecutor(max_workers=n_workers) as exe:
  211. futures = []
  212. for batch in batches:
  213. futures.append(exe.submit(batch_tokenize, batch, tokenizer))
  214. for fut in tqdm(as_completed(futures), total=len(futures), desc="Tokenizing"):
  215. arr = fut.result()
  216. results.append(arr)
  217. total_tokens += arr.shape[0]
  218. # 3.写memmap
  219. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  220. tokens_mm = np.memmap(save_path, dtype=np.int32, mode='w+', shape=(total_tokens,))
  221. pos = 0
  222. for arr in results:
  223. tokens_mm[pos:pos+arr.shape[0]] = arr
  224. pos += arr.shape[0]
  225. tokens_mm.flush()