| 1234567891011121314151617181920212223242526272829303132333435363738 |
- import hydra
- import mlflow
- from omegaconf import DictConfig
- import warnings
- warnings.filterwarnings("ignore")
- from clean_llm.models.qwen2_5 import Qwen2_5
- from clean_llm.models.cs336_lm import BasicsTransformerLM
- from clean_llm.train.pretrain import train
- from clean_llm.tokenizer.tokenizer import get_custom_tokenizer
- from clean_llm.utils import _to_device_and_compile, log_params_from_omegaconf_dict
@hydra.main(config_path="configs/", config_name="pretrain_cs336_lm", version_base=None)
def main(cfg: DictConfig):
    """Build the model selected by ``cfg.model_type`` and train it inside an MLflow run.

    Args:
        cfg: Hydra-composed configuration. This function reads ``exp_name``,
            ``model_type``, ``model`` and ``training``; the ``"qwen2_5"``
            branch additionally reads ``tokenizer``.

    Raises:
        ValueError: If ``cfg.model_type`` is neither ``"qwen2_5"`` nor
            ``"cs336_lm"``.
    """
    model_type = cfg.model_type
    # Reject an unknown model type before touching MLflow, so a typo in the
    # config neither creates an empty run nor surfaces later as a NameError.
    if model_type not in ("qwen2_5", "cs336_lm"):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'qwen2_5' or 'cs336_lm'."
        )

    mlflow.set_experiment(cfg.exp_name)
    # The context manager ends the run even when model construction or
    # training raises, instead of relying on a trailing end_run() call.
    with mlflow.start_run():
        model_config, training_config = cfg.model, cfg.training
        if model_type == "qwen2_5":
            tokenizer = get_custom_tokenizer(**cfg.tokenizer)
            # Copy the tokenizer-derived values into the model config so the
            # model is built with the tokenizer's vocab size and EOS id.
            model_config.vocab_size = tokenizer.vocab_size
            model_config.eos_token_id = tokenizer.eos_token_id
            model = Qwen2_5.from_config(model_config)
        else:  # "cs336_lm"
            model = BasicsTransformerLM(**model_config)

        model, device = _to_device_and_compile(model)
        train(model, device, training_config)
# Run the Hydra entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|