train.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from __future__ import annotations
  2. import os
  3. import regex as re
  4. from tqdm import tqdm
  5. from collections import defaultdict
  6. from typing import Dict
  7. from .utils import to_bytes_tuple, PAT
  8. def run_train_bpe(
  9. input_path: str | os.PathLike,
  10. vocab_size: int,
  11. special_tokens: list[str],
  12. **kwargs,
  13. ) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
  14. """Given the path to an input corpus, run train a BPE tokenizer and
  15. output its vocabulary and merges.
  16. Args:
  17. input_path (str | os.PathLike): Path to BPE tokenizer training data.
  18. vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
  19. special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
  20. These strings will never be split into multiple tokens, and will always be
  21. kept as a single token. If these special tokens occur in the `input_path`,
  22. they are treated as any other string.
  23. Returns:
  24. tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
  25. vocab:
  26. The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
  27. to bytes (token bytes)
  28. merges:
  29. BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
  30. representing that <token1> was merged with <token2>.
  31. Merges are ordered by order of creation.
  32. """
  33. # Step 1: Initialize Vocabulary
  34. vocab: Dict[int, bytes] = {i: bytes([i]) for i in range(256)}
  35. next_id = 256
  36. special_token_bytes = [token.encode("utf-8") for token in special_tokens]
  37. for token_bytes in special_token_bytes:
  38. if token_bytes not in vocab.values():
  39. vocab[next_id] = token_bytes
  40. next_id += 1
  41. # Step 2: Pre-tokenization
  42. pre_tokens_cnt = defaultdict(int)
  43. with open(input_path, "r", encoding="utf-8") as f:
  44. text = f.read()
  45. chunks = re.split("|".join(map(re.escape, special_tokens)), text)
  46. for chunk in tqdm(chunks, desc="Pretokenize"):
  47. for m in re.finditer(PAT, chunk):
  48. word = m.group(0)
  49. pre_tokens_cnt[to_bytes_tuple(word)] += 1 # key of pre_tokens_cnt e.g. (b'H', b'e', b'l', b'l', b'o')
  50. # Step 3: Compute BPE Merges
  51. merges = []
  52. init_vocab_size = len(vocab)
  53. num_merges = vocab_size - init_vocab_size
  54. # tqdm接管循环
  55. for _ in tqdm(range(num_merges), desc="BPE merges"):
  56. pair_counts = defaultdict(int)
  57. # Count all adjacent byte pairs
  58. for token, cnt in pre_tokens_cnt.items(): # token e.g. "hello"
  59. for i in range(len(token) - 1):
  60. pair = (token[i], token[i + 1])
  61. pair_counts[pair] += cnt
  62. if not pair_counts:
  63. break # No more pairs to merge
  64. # Find the most frequent pair(s)
  65. max_count = max(pair_counts.values())
  66. candidates = [k for k, v in pair_counts.items() if v == max_count]
  67. best_pair = max(candidates)
  68. a, b = best_pair
  69. # Create new token
  70. new_token = a + b
  71. vocab[next_id] = new_token
  72. next_id += 1
  73. # Apply the merge to all pre-tokenized sequences
  74. # 收集变更
  75. changes = []
  76. for token, cnt in pre_tokens_cnt.items():
  77. # Find all occurrences of the `best_pair` in `token`
  78. indices = [i for i in range(len(token) - 1) if token[i:i + 2] == best_pair]
  79. if indices:
  80. # Replace each occurrence with `new_token`
  81. new_pre_token = []
  82. i = 0
  83. while i < len(token):
  84. if i in indices:
  85. new_pre_token.append(new_token)
  86. i += 2
  87. else:
  88. new_pre_token.append(token[i])
  89. i += 1
  90. new_pre_token = tuple(new_pre_token)
  91. changes.append((token, new_pre_token, cnt))
  92. # 应用变更
  93. for old_token, new_pre_token, cnt in changes:
  94. pre_tokens_cnt[new_pre_token] = pre_tokens_cnt.get(new_pre_token, 0) + cnt
  95. del pre_tokens_cnt[old_token]
  96. # Record the merge
  97. merges.append((a, b))
  98. return vocab, merges