train_fast.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import os
  2. import re
  3. import time
  4. import heapq
  5. from collections import defaultdict
  6. from collections import Counter
  7. from multiprocessing import Pool, cpu_count
  8. from tqdm import tqdm
  9. from typing import Dict, Tuple, Union, Pattern
  10. PAT_COMPILED = re.compile(r"\S+")
  11. def pre_tokenize_and_count(
  12. args: Tuple[bytes, Dict[str, int], Union[Pattern, None]]
  13. ) -> Counter:
  14. """
  15. Pre-tokenize a chunk of bytes into token ids, handling special tokens.
  16. Returns a Counter of tokens.
  17. """
  18. chunk_bytes, special_token_to_id, delimiter_pattern_compiled = args
  19. chunk = chunk_bytes.decode("utf-8", errors="ignore")
  20. special_tokens_set = set(special_token_to_id.keys())
  21. words_list = []
  22. if delimiter_pattern_compiled:
  23. sub_chunks = delimiter_pattern_compiled.split(chunk)
  24. else:
  25. sub_chunks = [chunk]
  26. for sub_chunk in tqdm(sub_chunks, desc="Pre-tokenizing subchunks"):
  27. if not sub_chunk:
  28. continue
  29. if sub_chunk in special_tokens_set:
  30. token_id = special_token_to_id[sub_chunk]
  31. words_list.append((token_id,))
  32. else:
  33. for word_str in PAT_COMPILED.findall(sub_chunk):
  34. if word_str:
  35. byte_sequence = word_str.encode("utf-8")
  36. id_sequence = tuple(byte_sequence)
  37. words_list.append(id_sequence)
  38. return Counter(words_list)
  39. # Placeholder, should be replaced with your chunk splitter logic
  40. def find_chunk_boundaries(f, num_chunks, marker_bytes):
  41. f.seek(0, 2)
  42. file_size = f.tell()
  43. chunk_size = file_size // num_chunks
  44. boundaries = [0]
  45. for i in range(1, num_chunks):
  46. boundaries.append(min(i * chunk_size, file_size))
  47. boundaries.append(file_size)
  48. return boundaries
  49. class Node:
  50. """表示词内一个 token 节点,便于链表原地更新。"""
  51. def __init__(self, value, word_freq):
  52. self.value = value
  53. self.word_freq = word_freq # 共享引用,节省内存
  54. self.prev = None
  55. self.next = None
  56. class PQItem:
  57. """定义优先队列元素,实现自定义比较:频率优先,其次按字典序逆序。"""
  58. def __init__(self, freq, id_pair, byte_pair):
  59. self.freq = freq
  60. self.id_pair = id_pair
  61. self.byte_pair = byte_pair
  62. def __lt__(self, other):
  63. if self.freq != other.freq:
  64. return self.freq > other.freq # 频率高的先出
  65. return self.byte_pair > other.byte_pair # 字典序大的先出
  66. def run_train_bpe(
  67. input_path: str | os.PathLike,
  68. vocab_size: int,
  69. special_tokens: list[str],
  70. num_chunks: int = 4,
  71. num_processes: int = None,
  72. **kwargs,
  73. ) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
  74. before_pretokenization_time = time.time()
  75. # 1. Set up initial byte vocab and special tokens
  76. vocab = {i: bytes([i]) for i in range(256)}
  77. for token in special_tokens:
  78. token_bytes = token.encode("utf-8")
  79. if token_bytes not in vocab.values():
  80. vocab[len(vocab)] = token_bytes
  81. byte_to_token_id = {v: k for k, v in vocab.items()}
  82. special_token_to_id = {
  83. token: byte_to_token_id[token.encode("utf-8")] for token in special_tokens
  84. }
  85. # 2. Prepare special tokens regex delimiter
  86. delimiter_pattern_compiled = None
  87. if special_tokens:
  88. # Sort by length descending for proper greedy match
  89. special_tokens_sorted = sorted(
  90. [t.encode("utf-8") for t in special_tokens], key=len, reverse=True
  91. )
  92. escaped_tokens = [re.escape(t.decode("utf-8")) for t in special_tokens_sorted]
  93. delimiter_re = "|".join(escaped_tokens)
  94. if delimiter_re:
  95. delimiter_pattern_compiled = re.compile(f"({delimiter_re})")
  96. # 3. Read file, split into chunks for multiprocessing
  97. with open(input_path, "rb") as f:
  98. boundaries = find_chunk_boundaries(
  99. f, num_chunks, "<|endoftext|>".encode("utf-8")
  100. )
  101. chunk_args = []
  102. for start, end in zip(boundaries[:-1], boundaries[1:]):
  103. f.seek(start)
  104. chunk_bytes = f.read(end - start)
  105. chunk_args.append(
  106. (
  107. chunk_bytes,
  108. special_token_to_id,
  109. delimiter_pattern_compiled,
  110. )
  111. )
  112. # 4. Determine number of processes
  113. processes_to_use = num_processes
  114. if processes_to_use is None:
  115. processes_to_use = min(cpu_count(), 8)
  116. processes_to_use = min(processes_to_use, len(chunk_args))
  117. elapsed = time.time() - before_pretokenization_time
  118. print(f"Time taken before pretokenization: {elapsed:.2f} seconds")
  119. # 5. Multiprocess: Counter aggregation
  120. all_word_freqs = Counter()
  121. start_time = time.time()
  122. with Pool(processes=processes_to_use) as pool:
  123. print(
  124. f"Starting pre-tokenization with {processes_to_use} processes on {len(chunk_args)} chunks..."
  125. )
  126. results_iterator = pool.imap_unordered(pre_tokenize_and_count, chunk_args)
  127. for chunk_counter in tqdm(
  128. results_iterator, total=len(chunk_args), desc="Processing chunks", leave=True
  129. ):
  130. all_word_freqs.update(chunk_counter)
  131. print(f"Pre-tokenization and initial counting time: {time.time() - start_time:.2f} seconds")
  132. ### Pre-tokenization 结束
  133. pair_to_nodes = defaultdict(set)
  134. for word_tuple, count in tqdm(all_word_freqs.items(), desc="Building", leave=True):
  135. if len(word_tuple) < 2:
  136. continue
  137. # 所有链表节点共享 word_freq 引用,节省内存
  138. word_freq = {'count': count}
  139. head = Node(word_tuple[0], word_freq)
  140. prev_node = head
  141. for i in range(1, len(word_tuple)):
  142. curr_node = Node(word_tuple[i], word_freq)
  143. prev_node.next = curr_node
  144. curr_node.prev = prev_node
  145. pair = (prev_node.value, curr_node.value)
  146. pair_to_nodes[pair].add(prev_node)
  147. prev_node = curr_node
  148. pair_freqs = Counter()
  149. for pair, nodes in tqdm(pair_to_nodes.items(), desc="Counting pairs", leave=True):
  150. # 某个 pair 的出现次数 = 其所有节点所对应 word 的词频累加
  151. pair_freqs[pair] = sum(node.word_freq['count'] for node in nodes)
  152. pq = [
  153. PQItem(freq, p, (vocab[p[0]], vocab[p[1]]))
  154. for p, freq in pair_freqs.items()
  155. ]
  156. heapq.heapify(pq)
  157. ### BPE 开始
  158. merges = []
  159. num_merges = vocab_size - len(vocab)
  160. pbar = tqdm(total=num_merges, desc="Performing BPE merges")
  161. start_time = time.time()
  162. for _ in range(num_merges):
  163. if not pq:
  164. break
  165. # 取出频率最高的 pair,处理优先队列惰性删除的过期元素
  166. best_pair = None
  167. while pq:
  168. item = heapq.heappop(pq)
  169. if item.id_pair not in pair_freqs:
  170. continue # 已经被合并删除
  171. if pair_freqs[item.id_pair] == item.freq:
  172. best_pair = item.id_pair
  173. break
  174. if best_pair is None:
  175. break
  176. p1, p2 = best_pair
  177. # 合成新 token,添加到 merges/vocab
  178. new_token_id = len(vocab)
  179. merged_token_bytes = vocab[p1] + vocab[p2]
  180. merges.append((vocab[p1], vocab[p2]))
  181. vocab[new_token_id] = merged_token_bytes
  182. # 逐个更新包含改 pair 的词
  183. nodes_to_process = list(pair_to_nodes[best_pair])
  184. for node1 in nodes_to_process:
  185. node2 = node1.next
  186. if node2 is None:
  187. continue
  188. word_freq = node1.word_freq['count']
  189. # 更新左侧相邻 pair 的频率及映射关系
  190. if node1.prev:
  191. left = node1.prev
  192. old_left_pair = (left.value, node1.value)
  193. pair_freqs[old_left_pair] -= word_freq
  194. heapq.heappush(pq, PQItem(pair_freqs[old_left_pair], old_left_pair, (vocab[old_left_pair[0]], vocab[old_left_pair[1]])))
  195. pair_to_nodes[old_left_pair].discard(left)
  196. new_left_pair = (left.value, new_token_id)
  197. pair_to_nodes[new_left_pair].add(left)
  198. pair_freqs[new_left_pair] += word_freq
  199. heapq.heappush(pq, PQItem(pair_freqs[new_left_pair], new_left_pair, (vocab[new_left_pair[0]], vocab[new_left_pair[1]])))
  200. # 更新右侧相邻 pair 的频率及映射关系
  201. if node2.next:
  202. right = node2.next
  203. old_right_pair = (node2.value, right.value)
  204. pair_freqs[old_right_pair] -= word_freq
  205. heapq.heappush(pq, PQItem(pair_freqs[old_right_pair], old_right_pair, (vocab[old_right_pair[0]], vocab[old_right_pair[1]])))
  206. new_right_pair = (new_token_id, right.value)
  207. pair_to_nodes[old_right_pair].discard(node2)
  208. pair_to_nodes[new_right_pair].add(node1)
  209. pair_freqs[new_right_pair] += word_freq
  210. heapq.heappush(pq, PQItem(pair_freqs[new_right_pair], new_right_pair, (vocab[new_right_pair[0]], vocab[new_right_pair[1]])))
  211. # 链表合并:node1、node2合成 new_token_id
  212. node1.value = new_token_id
  213. node1.next = node2.next
  214. if node2.next:
  215. node2.next.prev = node1
  216. # 删除被合并 pair 的所有统计
  217. del pair_freqs[best_pair]
  218. del pair_to_nodes[best_pair]
  219. pbar.update(1)
  220. end_time = time.time()
  221. print(f"Merge time: {end_time - start_time:.2f} seconds")
  222. pbar.close()
  223. return vocab, merges