adapters.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from __future__ import annotations
  2. import os
  3. import torch
  4. import numpy as np
  5. from typing import IO, BinaryIO
  6. from collections.abc import Iterable
  7. def run_gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float) -> None:
  8. """Given a set of parameters, clip their combined gradients to have l2 norm at most max_l2_norm.
  9. Args:
  10. parameters (Iterable[torch.nn.Parameter]): collection of trainable parameters.
  11. max_l2_norm (float): a positive value containing the maximum l2-norm.
  12. The gradients of the parameters (parameter.grad) should be modified in-place.
  13. """
  14. # Filter parameters with gradients
  15. parameters_with_grad = [p for p in parameters if p.grad is not None]
  16. if len(parameters_with_grad) == 0:
  17. return
  18. # Calculate total L2 norm of all gradients
  19. total_norm = torch.sqrt(sum(torch.sum(p.grad.pow(2)) for p in parameters_with_grad))
  20. # Calculate clipping coefficient
  21. clip_coef = max_l2_norm / (total_norm + 1e-6) # Add small value to avoid division by zero
  22. # If total norm exceeds max_norm, scale down all gradients
  23. if clip_coef < 1.0:
  24. for p in parameters_with_grad:
  25. p.grad.mul_(clip_coef)
  26. def get_adamw_cls() -> type[torch.optim.Optimizer]:
  27. """
  28. Returns a torch.optim.Optimizer that implements AdamW.
  29. """
  30. class AdamW(torch.optim.Optimizer):
  31. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
  32. if not 0.0 <= lr:
  33. raise ValueError(f"Invalid learning rate: {lr}")
  34. if not 0.0 <= eps:
  35. raise ValueError(f"Invalid epsilon value: {eps}")
  36. if not 0.0 <= betas[0] < 1.0:
  37. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  38. if not 0.0 <= betas[1] < 1.0:
  39. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  40. defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
  41. super(AdamW, self).__init__(params, defaults)
  42. def step(self, closure=None):
  43. loss = None
  44. if closure is not None:
  45. loss = closure()
  46. for group in self.param_groups:
  47. for p in group['params']:
  48. if p.grad is None:
  49. continue
  50. # Perform stepweight decay
  51. p.data.mul_(1 - group['lr'] * group['weight_decay'])
  52. # Get parameter-specific state
  53. state = self.state[p]
  54. # State initialization
  55. if len(state) == 0:
  56. state['step'] = 0
  57. state['exp_avg'] = torch.zeros_like(p.data)
  58. state['exp_avg_sq'] = torch.zeros_like(p.data)
  59. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  60. beta1, beta2 = group['betas']
  61. step_size = group['lr']
  62. eps = group['eps']
  63. # Update state
  64. state['step'] += 1
  65. grad = p.grad.data
  66. # Decay the first and second moment running average coefficient
  67. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  68. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  69. denom = exp_avg_sq.sqrt().add_(eps)
  70. bias_correction1 = 1 - beta1 ** state['step']
  71. bias_correction2 = 1 - beta2 ** state['step']
  72. step_size = step_size * (bias_correction2 ** 0.5) / bias_correction1
  73. p.data.addcdiv_(exp_avg, denom, value=-step_size)
  74. return loss
  75. return AdamW
  76. def run_get_lr_cosine_schedule(
  77. it: int,
  78. max_learning_rate: float,
  79. min_learning_rate: float,
  80. warmup_iters: int,
  81. cosine_cycle_iters: int,
  82. ):
  83. """
  84. Given the parameters of a cosine learning rate decay schedule (with linear
  85. warmup) and an iteration number, return the learning rate at the given
  86. iteration under the specified schedule.
  87. Args:
  88. it (int): Iteration number to get learning rate for.
  89. max_learning_rate (float): alpha_max, the maximum learning rate for
  90. cosine learning rate schedule (with warmup).
  91. min_learning_rate (float): alpha_min, the minimum / final learning rate for
  92. the cosine learning rate schedule (with warmup).
  93. warmup_iters (int): T_w, the number of iterations to linearly warm-up
  94. the learning rate.
  95. cosine_cycle_iters (int): T_c, the number of cosine annealing iterations.
  96. Returns:
  97. Learning rate at the given iteration under the specified schedule.
  98. """
  99. if it < warmup_iters:
  100. # Warm-up 阶段:线性增加学习率
  101. lr = (it / warmup_iters) * max_learning_rate
  102. elif it <= cosine_cycle_iters:
  103. # Cosine Annealing 阶段:余弦函数衰减
  104. t = it - warmup_iters
  105. T = cosine_cycle_iters - warmup_iters
  106. cos_value = np.cos(np.pi * t / T)
  107. lr = min_learning_rate + 0.5 * (max_learning_rate - min_learning_rate) * (1 + cos_value)
  108. else:
  109. # Post-annealing 阶段:学习率保持最小值
  110. lr = min_learning_rate
  111. return lr
  112. def run_save_checkpoint(
  113. model: torch.nn.Module,
  114. optimizer: torch.optim.Optimizer,
  115. iteration: int,
  116. out: str | os.PathLike | BinaryIO | IO[bytes],
  117. ):
  118. """
  119. Given a model, optimizer, and an iteration number, serialize them to disk.
  120. Args:
  121. model (torch.nn.Module): Serialize the state of this model.
  122. optimizer (torch.optim.Optimizer): Serialize the state of this optimizer.
  123. iteration (int): Serialize this value, which represents the number of training iterations
  124. we've completed.
  125. out (str | os.PathLike | BinaryIO | IO[bytes]): Path or file-like object to serialize the model, optimizer, and iteration to.
  126. """
  127. checkpoint = {
  128. 'model_state_dict': model.state_dict(),
  129. 'optimizer_state_dict': optimizer.state_dict(),
  130. 'iteration': iteration
  131. }
  132. if isinstance(out, (str, os.PathLike)):
  133. with open(out, 'wb') as f:
  134. torch.save(checkpoint, f)
  135. else:
  136. torch.save(checkpoint, out)
  137. def run_load_checkpoint(
  138. src: str | os.PathLike | BinaryIO | IO[bytes],
  139. model: torch.nn.Module,
  140. optimizer: torch.optim.Optimizer,
  141. ):
  142. """
  143. Given a serialized checkpoint (path or file-like object), restore the
  144. serialized state to the given model and optimizer.
  145. Return the number of iterations that we previously serialized in
  146. the checkpoint.
  147. Args:
  148. src (str | os.PathLike | BinaryIO | IO[bytes]): Path or file-like object to serialized checkpoint.
  149. model (torch.nn.Module): Restore the state of this model.
  150. optimizer (torch.optim.Optimizer): Restore the state of this optimizer.
  151. Returns:
  152. int: the previously-serialized number of iterations.
  153. """
  154. if isinstance(src, (str, os.PathLike)):
  155. with open(src, 'rb') as f:
  156. checkpoint = torch.load(f, weights_only=False)
  157. else:
  158. checkpoint = torch.load(src)
  159. model.load_state_dict(checkpoint['model_state_dict'])
  160. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  161. return checkpoint['iteration']