| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- import os
- import re
- import time
- import heapq
- from collections import defaultdict
- from collections import Counter
- from multiprocessing import Pool, cpu_count
- from tqdm import tqdm
- from typing import Dict, Tuple, Union, Pattern
# Pre-token pattern: every match is a maximal run of non-whitespace characters.
PAT_COMPILED = re.compile(r"\S+")


def pre_tokenize_and_count(
    args: Tuple[bytes, Dict[str, int], Union[Pattern, None]]
) -> Counter:
    """Count the pre-tokens found in one raw chunk of the corpus.

    ``args`` is a single packed tuple (so the function can be handed directly
    to a pool ``map``-style call) of
    ``(chunk_bytes, special_token_to_id, delimiter_pattern_compiled)``.

    The chunk is decoded as UTF-8 with undecodable bytes dropped, split with
    the delimiter pattern when one is given, and each non-empty piece is
    handled as follows:

    * a piece that is exactly a key of ``special_token_to_id`` is counted as
      the one-element tuple ``(token_id,)``;
    * any other piece is cut into ``PAT_COMPILED`` matches, and each match is
      counted as the tuple of its UTF-8 byte values.

    Returns:
        A ``Counter`` mapping those integer tuples to how often they occurred.
    """
    raw, special_token_to_id, splitter = args
    text = raw.decode("utf-8", errors="ignore")

    # Without a delimiter pattern the whole chunk is treated as one piece.
    pieces = splitter.split(text) if splitter else [text]

    counts = Counter()
    for piece in tqdm(pieces, desc="Pre-tokenizing subchunks"):
        if not piece:
            continue
        if piece in special_token_to_id:
            # Special tokens stay atomic: a single id, never broken into bytes.
            counts[(special_token_to_id[piece],)] += 1
            continue
        counts.update(
            tuple(word.encode("utf-8"))
            for word in PAT_COMPILED.findall(piece)
            if word
        )
    return counts
def _next_marker_offset(f, start, file_size, marker_bytes, block_size):
    """Return the offset of the first ``marker_bytes`` at or after ``start``.

    The file is scanned in ``block_size`` reads. The last
    ``len(marker_bytes) - 1`` bytes of each window are carried into the next
    one so a marker straddling two reads is still found. Returns ``file_size``
    when no marker occurs before the end of the file.
    """
    carry_len = len(marker_bytes) - 1
    carried = b""
    block_start = start  # file offset of the block about to be read
    f.seek(start)
    while True:
        block = f.read(block_size)
        if not block:
            return file_size
        window = carried + block
        hit = window.find(marker_bytes)
        if hit != -1:
            # `window` begins len(carried) bytes before `block_start`.
            return block_start - len(carried) + hit
        block_start += len(block)
        carried = window[-carry_len:] if carry_len else b""


def find_chunk_boundaries(f, num_chunks, marker_bytes, scan_block_size=4096):
    """Split a binary file into ``num_chunks`` ranges that begin on a marker.

    The file is first divided into equal-sized pieces. Every interior cut is
    then moved forward to the start of the next occurrence of
    ``marker_bytes``, so no chunk begins in the middle of a word, a special
    token, or a multi-byte UTF-8 character. A cut with no marker after it
    moves to the end of the file, which leaves the following chunk(s) empty.

    Args:
        f: Seekable file object opened in binary mode. Its position is
            changed by this call.
        num_chunks: Number of chunks to produce; must be at least 1.
        marker_bytes: Byte string that chunks are aligned to. When empty or
            ``None`` the plain equal-sized cuts are returned unchanged.
        scan_block_size: Number of bytes read per step while searching for
            the marker.

    Returns:
        A non-decreasing list of ``num_chunks + 1`` byte offsets that starts
        with ``0`` and ends with the file size; consecutive entries delimit
        one chunk.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    f.seek(0, 2)
    file_size = f.tell()
    chunk_size = file_size // num_chunks
    guesses = [min(i * chunk_size, file_size) for i in range(1, num_chunks)]

    if not marker_bytes:
        return [0, *guesses, file_size]

    block_size = max(1, scan_block_size)
    boundaries = [0]
    for guess in guesses:
        previous = boundaries[-1]
        if previous >= guess:
            # The previous cut already sits on the first marker at or after
            # this guess (or on the file start/end), so reuse it.
            boundaries.append(previous)
        else:
            boundaries.append(
                _next_marker_offset(f, guess, file_size, marker_bytes, block_size)
            )
    boundaries.append(file_size)
    return boundaries
class Node:
    """One token inside a word, kept as a doubly linked list cell.

    Storing a word as linked nodes lets a merge rewrite it in place.
    """

    def __init__(self, value, word_freq):
        # Neighbour links start empty; whoever builds the list wires them up.
        self.prev = self.next = None
        # Stored by reference (not copied), so the same object can be shared
        # by every node of a word to save memory.
        self.word_freq = word_freq
        self.value = value
class PQItem:
    """Priority-queue entry for a candidate pair.

    Ordering is defined so that a min-heap yields the highest frequency
    first and, among equal frequencies, the greater ``byte_pair`` first.
    """

    def __init__(self, freq, id_pair, byte_pair):
        self.freq, self.id_pair, self.byte_pair = freq, id_pair, byte_pair

    def __lt__(self, other):
        # "Less than" means "should be popped earlier": compare frequency
        # first, then fall back to byte_pair, both in descending order.
        mine = (self.freq, self.byte_pair)
        theirs = (other.freq, other.byte_pair)
        return mine > theirs
def _pop_best_pair(pq, pair_freqs):
    """Pop heap entries until one reflects the current count of its pair.

    The heap is cleaned lazily: an entry is stale when its pair no longer
    exists in ``pair_freqs`` or when the recorded frequency differs from the
    current one. Returns the id pair of the first up-to-date entry with a
    positive count, or ``None`` when no such entry is left.
    """
    while pq:
        item = heapq.heappop(pq)
        live_count = pair_freqs.get(item.id_pair)
        if live_count is None or live_count != item.freq:
            continue  # stale entry
        if live_count > 0:
            return item.id_pair
    return None


def _merge_pair_occurrences(best_pair, new_token_id, pair_freqs, pair_to_nodes):
    """Rewrite every occurrence of ``best_pair`` as ``new_token_id`` in place.

    ``best_pair`` is removed from ``pair_freqs`` and ``pair_to_nodes``; the
    counts and node sets of the neighbouring pairs are adjusted. Returns the
    set of pairs whose count was changed so the caller can refresh the heap.
    """
    # Nodes that currently start an occurrence of best_pair.
    live = pair_to_nodes.pop(best_pair, set())
    pair_freqs.pop(best_pair, None)
    touched = set()

    def drop(pair, node, count):
        # Forget the occurrence of `pair` that starts at `node`.
        if pair == best_pair:
            # An overlapping occurrence of the pair being merged is consumed
            # by this merge; it must not be merged again.
            live.discard(node)
            return
        pair_to_nodes[pair].discard(node)
        pair_freqs[pair] -= count
        touched.add(pair)

    def add(pair, node, count):
        # Register a new occurrence of `pair` that starts at `node`.
        pair_to_nodes[pair].add(node)
        pair_freqs[pair] += count
        touched.add(pair)

    for start in list(live):
        # `start` may already have been consumed by a neighbouring merge.
        while start in live:
            # Occurrences can only overlap when both halves of the pair are
            # the same token (e.g. "aaa"). Step back to the leftmost
            # occurrence of such a run so merging proceeds left to right.
            first = start
            while first.prev is not None and first.prev in live:
                first = first.prev
            live.discard(first)
            second = first.next
            if second is None:
                continue
            count = first.word_freq['count']
            left = first.prev
            right = second.next
            if left is not None:
                drop((left.value, first.value), left, count)
                add((left.value, new_token_id), left, count)
            if right is not None:
                drop((second.value, right.value), second, count)
                add((new_token_id, right.value), first, count)
            # Splice: `first` becomes the merged token, `second` is unlinked.
            first.value = new_token_id
            first.next = right
            if right is not None:
                right.prev = first
    return touched


def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    num_chunks: int = 4,
    num_processes: int = None,
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE vocabulary on the file at ``input_path``.

    The file is read in ``num_chunks`` pieces that are pre-tokenized in a
    process pool. Each distinct pre-token is turned into a linked list of
    token nodes, and pairs are merged greedily: highest count first, ties
    broken in favour of the greater pair of byte strings.

    Args:
        input_path: Path of the training corpus; opened in binary mode.
        vocab_size: Target vocabulary size, counting the 256 single bytes
            and the special tokens.
        special_tokens: Strings added to the vocabulary as whole tokens and
            used as split points during pre-tokenization.
        num_chunks: Number of chunks requested from
            ``find_chunk_boundaries``.
        num_processes: Worker processes to use. ``None`` means
            ``min(cpu_count(), 8)``; the value is capped at the number of
            chunks and raised to at least 1.
        **kwargs: Accepted and ignored.

    Returns:
        ``(vocab, merges)`` where ``vocab`` maps token id to its bytes and
        ``merges`` lists the merged byte pairs in the order they were made.
    """
    before_pretokenization_time = time.time()

    # 1. Initial vocabulary: the 256 single bytes, then the special tokens.
    vocab = {i: bytes([i]) for i in range(256)}
    byte_to_token_id = {token_bytes: i for i, token_bytes in vocab.items()}
    for token in special_tokens:
        token_bytes = token.encode("utf-8")
        if token_bytes not in byte_to_token_id:
            token_id = len(vocab)
            vocab[token_id] = token_bytes
            byte_to_token_id[token_bytes] = token_id
    special_token_to_id = {
        token: byte_to_token_id[token.encode("utf-8")] for token in special_tokens
    }

    # 2. Regex that splits text on special tokens. The capturing group makes
    #    re.split return the tokens themselves as separate pieces.
    delimiter_pattern_compiled = None
    # Longest first (by encoded length) so a token that is a prefix of a
    # longer one cannot shadow it. Empty strings are left out because an
    # empty alternative would match between every character.
    delimiter_tokens = sorted(
        (t for t in special_tokens if t),
        key=lambda t: len(t.encode("utf-8")),
        reverse=True,
    )
    delimiter_re = "|".join(re.escape(t) for t in delimiter_tokens)
    if delimiter_re:
        delimiter_pattern_compiled = re.compile(f"({delimiter_re})")

    # 3. Read the file and cut it into chunks for the worker processes.
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(
            f, num_chunks, "<|endoftext|>".encode("utf-8")
        )
        chunk_args = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk_bytes = f.read(end - start)
            chunk_args.append(
                (
                    chunk_bytes,
                    special_token_to_id,
                    delimiter_pattern_compiled,
                )
            )

    # 4. Number of worker processes. Pool() rejects values below 1.
    processes_to_use = num_processes
    if processes_to_use is None:
        processes_to_use = min(cpu_count(), 8)
    processes_to_use = max(1, min(processes_to_use, len(chunk_args)))
    elapsed = time.time() - before_pretokenization_time
    print(f"Time taken before pretokenization: {elapsed:.2f} seconds")

    # 5. Pre-tokenize in parallel and add up the per-chunk counters.
    all_word_freqs = Counter()
    start_time = time.time()
    with Pool(processes=processes_to_use) as pool:
        print(
            f"Starting pre-tokenization with {processes_to_use} processes on {len(chunk_args)} chunks..."
        )
        results_iterator = pool.imap_unordered(pre_tokenize_and_count, chunk_args)
        for chunk_counter in tqdm(
            results_iterator, total=len(chunk_args), desc="Processing chunks", leave=True
        ):
            all_word_freqs.update(chunk_counter)
    print(f"Pre-tokenization and initial counting time: {time.time() - start_time:.2f} seconds")

    # Pre-tokenization done. Build one linked list per distinct word and
    # index, for every pair, the nodes at which an occurrence starts.
    pair_to_nodes = defaultdict(set)
    for word_tuple, count in tqdm(all_word_freqs.items(), desc="Building", leave=True):
        if len(word_tuple) < 2:
            continue  # a single token contains no pair
        # All nodes of a word share one count object to save memory.
        word_freq = {'count': count}
        prev_node = Node(word_tuple[0], word_freq)
        for value in word_tuple[1:]:
            curr_node = Node(value, word_freq)
            prev_node.next = curr_node
            curr_node.prev = prev_node
            pair_to_nodes[(prev_node.value, curr_node.value)].add(prev_node)
            prev_node = curr_node

    pair_freqs = Counter()
    for pair, nodes in tqdm(pair_to_nodes.items(), desc="Counting pairs", leave=True):
        # Count of a pair = sum of the word counts of all its occurrences.
        pair_freqs[pair] = sum(node.word_freq['count'] for node in nodes)

    pq = [
        PQItem(freq, p, (vocab[p[0]], vocab[p[1]]))
        for p, freq in pair_freqs.items()
    ]
    heapq.heapify(pq)

    # BPE merge loop.
    merges = []
    num_merges = max(0, vocab_size - len(vocab))
    pbar = tqdm(total=num_merges, desc="Performing BPE merges")
    start_time = time.time()
    for _ in range(num_merges):
        best_pair = _pop_best_pair(pq, pair_freqs)
        if best_pair is None:
            break  # no pair with a positive count is left
        p1, p2 = best_pair

        # Register the merged token before refreshing the heap, because the
        # new heap entries look up its bytes in `vocab`.
        new_token_id = len(vocab)
        merges.append((vocab[p1], vocab[p2]))
        vocab[new_token_id] = vocab[p1] + vocab[p2]

        touched = _merge_pair_occurrences(
            best_pair, new_token_id, pair_freqs, pair_to_nodes
        )

        # One fresh heap entry per changed pair; pairs that dropped to zero
        # are removed so they can never be selected for a merge.
        for pair in touched:
            freq = pair_freqs[pair]
            if freq > 0:
                heapq.heappush(
                    pq, PQItem(freq, pair, (vocab[pair[0]], vocab[pair[1]]))
                )
            else:
                pair_freqs.pop(pair, None)
                pair_to_nodes.pop(pair, None)
        pbar.update(1)
    end_time = time.time()
    print(f"Merge time: {end_time - start_time:.2f} seconds")
    pbar.close()

    return vocab, merges
|