pretrain.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. import sys
  3. import mlflow
  4. import torch
  5. import torch.nn.functional as F
  6. import pathlib
  7. import numpy as np
  8. from tqdm import tqdm
  9. from hydra.core.hydra_config import HydraConfig
  10. from .adapters import *
  11. def get_memmap_dataset(path, dtype=np.int32):
  12. arr = np.memmap(path, dtype=dtype, mode="r") # 单列token id序列
  13. return arr
  14. def get_batch(memmap_arr, batch_size, context_length):
  15. N = len(memmap_arr)
  16. ix = np.random.randint(0, N-context_length-1, size=(batch_size,))
  17. x = np.stack([memmap_arr[i:i+context_length] for i in ix])
  18. y = np.stack([memmap_arr[i+1:i+context_length+1] for i in ix])
  19. return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
  20. def memmap_val_iterator(memmap_arr, batch_size, context_length):
  21. N = len(memmap_arr)
  22. nb = (N-context_length-1)//batch_size
  23. for bi in range(nb):
  24. base = bi*batch_size
  25. x = np.stack([memmap_arr[i:i+context_length] for i in range(base, base+batch_size)])
  26. y = np.stack([memmap_arr[i+1:i+context_length+1] for i in range(base, base+batch_size)])
  27. yield torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
  28. def train(model, device, args):
  29. os.makedirs(args.save_path, exist_ok=True)
  30. # 2. 加载数据集
  31. train_data = get_memmap_dataset(args.train_data_path)
  32. val_data = get_memmap_dataset(args.val_data_path)
  33. # 3. 构建优化器
  34. optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  35. # 4. 恢复断点
  36. start_iter = 0
  37. if args.resume_checkpoint:
  38. print(f"Resuming from checkpoint {args.resume_checkpoint}")
  39. resume_ckpt_path = pathlib.Path(HydraConfig.get().runtime.output_dir) / f"{args.save_path}/ckpt_iter{args.resume_checkpoint}.pt"
  40. start_iter = run_load_checkpoint(resume_ckpt_path, model, optimizer)
  41. print(f"Resumed at iteration {start_iter} from path {resume_ckpt_path}")
  42. # 5. 训练loop
  43. pbar = tqdm(range(start_iter, args.train_steps), desc="Training", leave=False)
  44. for iteration in pbar:
  45. model.train()
  46. x, y = get_batch(train_data, args.batch_size, args.context_length)
  47. x, y = x.to(device), y.to(device)
  48. logits, _ = model(x)
  49. loss = F.cross_entropy(
  50. logits.reshape(-1, logits.shape[-1]),
  51. y.reshape(-1)
  52. )
  53. optimizer.zero_grad()
  54. loss.backward()
  55. run_gradient_clipping(model.parameters(), args.clip_grad_norm)
  56. # 更新学习率
  57. lr = run_get_lr_cosine_schedule(
  58. iteration, args.lr, args.min_lr, args.warmup_iters, args.cosine_iters
  59. )
  60. for param_group in optimizer.param_groups:
  61. param_group['lr'] = lr
  62. optimizer.step()
  63. pbar.set_postfix(loss=loss.item(), lr=lr)
  64. mlflow.log_metric("loss", loss.item(), step=iteration)
  65. mlflow.log_metric("lr", lr, step=iteration)
  66. # 验证
  67. if (iteration+1) % args.val_interval == 0:
  68. model.eval()
  69. with torch.no_grad():
  70. val_losses = []
  71. count = 0
  72. for x_val, y_val in memmap_val_iterator(val_data, args.batch_size, args.context_length):
  73. x_val, y_val = x_val.to(device), y_val.to(device)
  74. val_logits, _ = model(x_val)
  75. val_loss = F.cross_entropy(
  76. val_logits.reshape(-1, val_logits.shape[-1]),
  77. y_val.reshape(-1)
  78. )
  79. val_losses.append(val_loss.item())
  80. count += 1
  81. if count >= args.val_batches:
  82. break
  83. val_loss_mean = np.mean(val_losses)
  84. mlflow.log_metric("val_loss", val_loss_mean, step=iteration)
  85. print(f"iter {iteration+1:05d}: VALID loss = {val_loss_mean:.4f}")
  86. # 保存
  87. if (iteration+1) % args.save_interval == 0:
  88. ckpt_name = os.path.join(args.save_path, f"ckpt_iter{iteration+1}.pt")
  89. run_save_checkpoint(model, optimizer, iteration+1, ckpt_name)
  90. print(f"Checkpoint saved to {ckpt_name}")