train_sft_v0.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import os
  2. import time
  3. import torch
  4. import hydra
  5. import mlflow
  6. from pathlib import Path
  7. from typing import Dict, List
  8. from omegaconf import DictConfig
  9. from tqdm.auto import tqdm
  10. from datasets import load_dataset
  11. from transformers import AutoModelForCausalLM, AutoTokenizer
  12. from clean_llm.train.sft import (
  13. run_tokenize_prompt_and_output,
  14. run_get_response_log_probs,
  15. run_sft_microbatch_train_step,
  16. run_parse_gsm8k_response,
  17. evaluate_gsm8k
  18. )
  19. from clean_llm.utils import _to_device_and_compile, log_params_from_omegaconf_dict
  20. def save_checkpoint(model, tokenizer, save_dir: Path):
  21. save_dir.mkdir(parents=True, exist_ok=True)
  22. model.save_pretrained(save_dir)
  23. tokenizer.save_pretrained(save_dir)
  24. print(f"✅ 保存模型到 {save_dir}")
  25. def train(
  26. cfg: DictConfig,
  27. model: AutoModelForCausalLM,
  28. tokenizer: AutoTokenizer,
  29. train_prompt_strs: List[str],
  30. train_output_strs: List[str],
  31. test_prompt_strs: List[str],
  32. test_output_strs: List[str],
  33. ):
  34. os.makedirs(cfg.checkpoint_dir, exist_ok=True)
  35. os.makedirs(cfg.csv_dir, exist_ok=True)
  36. # 支持梯度累积
  37. micro_bs = cfg.micro_batch_size
  38. grad_accum_steps = cfg.gradient_accumulation_steps
  39. global_batch_size = micro_bs * grad_accum_steps
  40. num_epochs = cfg.num_epochs
  41. max_steps = cfg.get("max_steps", None) # 可选,按 step 训练
  42. eval_steps = cfg.eval_steps
  43. save_steps = cfg.save_steps
  44. project_name = cfg.get("mlflow_project", "sft-gsm8k")
  45. run_name = cfg.get("mlflow_run", f"run-{int(time.time())}")
  46. # ------------ MLflow ------------
  47. mlflow.set_experiment(project_name)
  48. mlflow.start_run(run_name=run_name)
  49. mlflow.log_params(cfg)
  50. # ------------ 数据 ------------
  51. train_inputs = run_tokenize_prompt_and_output(
  52. train_prompt_strs, train_output_strs, tokenizer
  53. )
  54. for k, v in train_inputs.items():
  55. train_inputs[k] = v.to(cfg.train_device)
  56. # ------------ 训练状态 ------------
  57. optimizer = torch.optim.AdamW(
  58. model.parameters(),
  59. lr=cfg.learning_rate,
  60. weight_decay=cfg.weight_decay,
  61. )
  62. step = 0
  63. epoch = 0
  64. model.train()
  65. # ------------ 训练循环 ------------
  66. total_samples = len(train_prompt_strs)
  67. samples_seen = 0
  68. pbar = tqdm(total=total_samples * num_epochs, desc="Training")
  69. while True:
  70. for i in range(0, total_samples, micro_bs):
  71. # 构造 micro-batch
  72. end = min(i + micro_bs, total_samples)
  73. batch_input_ids = train_inputs["input_ids"][i:end]
  74. batch_labels = train_inputs["labels"][i:end]
  75. batch_response_mask = train_inputs["response_mask"][i:end]
  76. # 前向 & 计算 loss
  77. res = run_get_response_log_probs(
  78. model,
  79. batch_input_ids,
  80. batch_labels,
  81. return_token_entropy=False,
  82. )
  83. loss, _ = run_sft_microbatch_train_step(
  84. res["log_probs"],
  85. batch_response_mask,
  86. gradient_accumulation_steps=grad_accum_steps,
  87. normalize_constant=1,
  88. )
  89. samples_seen += end - i
  90. pbar.update(end - i)
  91. # 梯度累积
  92. if (step + 1) % grad_accum_steps == 0:
  93. optimizer.step()
  94. optimizer.zero_grad()
  95. # ------------ 日志 ------------
  96. mlflow.log_metric("loss", loss.item(), step=step)
  97. # ------------ 评估 ------------
  98. if (step + 1) % eval_steps == 0:
  99. eval_res = evaluate_gsm8k(model_id=cfg.model_path,
  100. policy=model,
  101. tokenizer=tokenizer,
  102. prompt_strs=test_prompt_strs,
  103. output_strs=test_output_strs,
  104. save_path=os.path.join(cfg.csv_dir, f'step_{step}.csv'),
  105. device=cfg.eval_device
  106. )
  107. mlflow.log_metrics(eval_res, step=step)
  108. model.train()
  109. # ------------ 保存 ------------
  110. if (step + 1) % save_steps == 0:
  111. ckpt_dir = Path(cfg.checkpoint_dir) / f"checkpoint-{step}"
  112. save_checkpoint(model, tokenizer, ckpt_dir)
  113. step += 1
  114. # 可选按 step 终止
  115. if max_steps and step >= max_steps:
  116. break
  117. epoch += 1
  118. if epoch >= num_epochs or (max_steps and step >= max_steps):
  119. break
  120. pbar.close()
  121. # ------------ 训练结束 ------------
  122. final_ckpt_dir = Path(cfg.checkpoint_dir) / "final"
  123. save_checkpoint(model, tokenizer, final_ckpt_dir)
  124. mlflow.end_run()
  125. print("🎉 训练完成")
  126. # ---------------- Hydra 入口 ----------------
  127. @hydra.main(config_path="configs", config_name="sft_gsm8k", version_base=None)
  128. def main(cfg: DictConfig):
  129. model = AutoModelForCausalLM.from_pretrained(
  130. cfg.model_path, torch_dtype=torch.bfloat16, device_map=cfg.train_device
  131. )
  132. tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, padding_side='left')
  133. # model, device = _to_device_and_compile(model, cfg.train_device)
  134. print(f"Load huggingface model from {cfg.model_path} on device {cfg.train_device}")
  135. dataset = load_dataset(cfg.dataset_path, "main")
  136. train_prompt_strs = [ex["question"] for ex in dataset["train"]]
  137. train_output_strs = [ex["answer"] for ex in dataset["train"]]
  138. test_prompt_strs = [ex["question"] for ex in dataset["test"]]
  139. test_output_strs = [ex["answer"] for ex in dataset["test"]]
  140. print(f"Train samples: {len(train_prompt_strs)}, Test samples: {len(test_prompt_strs)}")
  141. train(
  142. cfg,
  143. model,
  144. tokenizer,
  145. train_prompt_strs,
  146. train_output_strs,
  147. test_prompt_strs,
  148. test_output_strs,
  149. )
  150. if __name__ == "__main__":
  151. main()