train_tokenizer.py 941 B

1234567891011121314151617181920212223242526272829303132
import os
import pickle
from pathlib import Path

import hydra
from omegaconf import DictConfig

# from clean_llm.tokenizer.train import run_train_bpe  # slow version
from clean_llm.tokenizer.train_fast import run_train_bpe  # fast version
  7. @hydra.main(config_path="configs", config_name="tokenizer", version_base=None)
  8. def main(cfg: DictConfig):
  9. vocab, merges = run_train_bpe(
  10. input_path=cfg.input_path,
  11. vocab_size=cfg.vocab_size,
  12. special_tokens=cfg.special_tokens,
  13. num_chunks=cfg.num_chunks,
  14. num_processes=cfg.num_processes
  15. )
  16. os.makedirs(cfg.tokenizer_dir, exist_ok=True)
  17. with open(cfg.vocab_path, "wb") as f:
  18. pickle.dump(vocab, f)
  19. with open(cfg.merges_path, "wb") as f:
  20. pickle.dump(merges, f)
  21. # 统计最长token
  22. longest_token = max(vocab.values(), key=len)
  23. print("最长token:", longest_token, "长度:", len(longest_token))
  24. if __name__ == "__main__":
  25. main()