# pretrain.py: Hydra entry point that builds a language model and runs pretraining, tracked with MLflow.
  1. import hydra
  2. import mlflow
  3. from omegaconf import DictConfig
  4. import warnings
  5. warnings.filterwarnings("ignore")
  6. from clean_llm.models.qwen2_5 import Qwen2_5
  7. from clean_llm.models.cs336_lm import BasicsTransformerLM
  8. from clean_llm.train.pretrain import train
  9. from clean_llm.tokenizer.tokenizer import get_custom_tokenizer
  10. from clean_llm.utils import _to_device_and_compile, log_params_from_omegaconf_dict
  11. @hydra.main(config_path="configs/", config_name="pretrain_cs336_lm", version_base=None)
  12. def main(cfg: DictConfig):
  13. mlflow.set_experiment(cfg.exp_name)
  14. mlflow.start_run()
  15. if cfg.model_type == "qwen2_5":
  16. model_config, training_config, tokenizer_config = cfg.model, cfg.training, cfg.tokenizer
  17. tokenizer = get_custom_tokenizer(**tokenizer_config)
  18. model_config.vocab_size = tokenizer.vocab_size
  19. model_config.eos_token_id = tokenizer.eos_token_id
  20. model = Qwen2_5.from_config(model_config)
  21. elif cfg.model_type == "cs336_lm":
  22. model_config, training_config = cfg.model, cfg.training
  23. model = BasicsTransformerLM(**model_config)
  24. model, device = _to_device_and_compile(model)
  25. train(model, device, training_config)
  26. mlflow.end_run()
  27. if __name__ == '__main__':
  28. main()