mars.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. """ PyTorch MARS Optimizer
  2. Code simplified from https://github.com/AGI-Arena/MARS
  3. Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models - https://arxiv.org/abs/2411.10438
  4. @article{yuan2024mars,
  5. title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
  6. author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
  7. journal={arXiv preprint arXiv:2411.10438},
  8. year={2024}
  9. }
  10. """
  11. # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
  12. # SPDX-License-Identifier: Apache-2.0
  13. import math
  14. from typing import Optional, Tuple
  15. import torch
  16. from torch.optim.optimizer import Optimizer
  17. from ._types import ParamsT
  18. def _mars_single_tensor_step(
  19. p: torch.Tensor,
  20. grad: torch.Tensor,
  21. exp_avg: torch.Tensor,
  22. exp_avg_sq: torch.Tensor,
  23. lr: float,
  24. weight_decay: float,
  25. beta1: float,
  26. beta2: float,
  27. last_grad: torch.Tensor,
  28. eps: float,
  29. step: int,
  30. gamma: float,
  31. mars_type: str,
  32. is_grad_2d: bool,
  33. optimize_1d: bool,
  34. lr_1d_factor: bool,
  35. betas_1d: Tuple[float, float],
  36. caution: bool,
  37. ):
  38. # optimize_1d ==> use MARS for 1d param, else use AdamW
  39. if optimize_1d or is_grad_2d:
  40. one_minus_beta1 = 1. - beta1
  41. if step == 1:
  42. # this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
  43. c_t = grad
  44. else:
  45. c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
  46. c_t_norm = torch.norm(c_t)
  47. if c_t_norm > 1.:
  48. c_t = c_t / c_t_norm
  49. exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
  50. if caution:
  51. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  52. mask = (exp_avg * grad > 0).to(grad.dtype)
  53. mask.div_(mask.mean().clamp_(min=1e-3))
  54. exp_avg = exp_avg * mask
  55. if mars_type == "adamw":
  56. exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
  57. bias_correction1 = 1.0 - beta1 ** step
  58. bias_correction2 = 1.0 - beta2 ** step
  59. denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
  60. update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
  61. elif mars_type == "lion":
  62. update = p * weight_decay + exp_avg.sign()
  63. else:
  64. assert False
  65. p.add_(update, alpha=-lr)
  66. else:
  67. beta1_1d, beta2_1d = betas_1d
  68. exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d)
  69. exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d)
  70. bias_correction1 = 1.0 - beta1_1d ** step
  71. bias_correction2 = 1.0 - beta2_1d ** step
  72. denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
  73. if caution:
  74. mask = (exp_avg * grad > 0).to(grad.dtype)
  75. mask.div_(mask.mean().clamp_(min=1e-3))
  76. exp_avg = exp_avg * mask
  77. update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
  78. p.add_(update, alpha=-(lr * lr_1d_factor))
  79. return exp_avg, exp_avg_sq
  80. class Mars(Optimizer):
  81. """ MARS Optimizer
  82. Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models
  83. https://arxiv.org/abs/2411.10438
  84. """
  85. def __init__(
  86. self,
  87. params: ParamsT,
  88. lr: float = 3e-3,
  89. betas: Tuple[float, float] = (0.9, 0.99),
  90. eps: float = 1e-8,
  91. weight_decay: float = 0.,
  92. gamma: float = 0.025,
  93. mars_type: str = "adamw",
  94. optimize_1d: bool = False,
  95. lr_1d_factor: float = 1.0,
  96. betas_1d: Optional[Tuple[float, float]] = None,
  97. caution: bool = False
  98. ):
  99. if not 0.0 <= lr:
  100. raise ValueError("Invalid learning rate: {}".format(lr))
  101. if not 0.0 <= eps:
  102. raise ValueError("Invalid epsilon value: {}".format(eps))
  103. if not 0.0 <= betas[0] < 1.0:
  104. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  105. if not 0.0 <= betas[1] < 1.0:
  106. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  107. assert mars_type in ["adamw", "lion"], "MARS type not supported"
  108. defaults = dict(
  109. lr=lr,
  110. betas=betas,
  111. eps=eps,
  112. weight_decay=weight_decay,
  113. mars_type=mars_type,
  114. gamma=gamma,
  115. optimize_1d=optimize_1d,
  116. lr_1d_factor=lr_1d_factor,
  117. betas_1d=betas_1d or betas,
  118. caution=caution,
  119. )
  120. super(Mars, self).__init__(params, defaults)
  121. def __setstate__(self, state):
  122. super(Mars, self).__setstate__(state)
  123. for group in self.param_groups:
  124. group.setdefault('caution', False)
  125. @torch.no_grad()
  126. def step(self, closure=None):
  127. """Performs a single optimization step.
  128. Arguments:
  129. closure (callable, optional): A closure that reevaluates the model
  130. and returns the loss.
  131. """
  132. loss = None
  133. if closure is not None:
  134. with torch.enable_grad():
  135. loss = closure()
  136. for group in self.param_groups:
  137. for p in group['params']:
  138. if p.grad is None:
  139. continue
  140. grad = p.grad
  141. if grad.is_sparse:
  142. raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
  143. state = self.state[p]
  144. # State initialization
  145. if len(state) <= 1:
  146. state['step'] = 0
  147. # Exponential moving average of gradient values
  148. state['exp_avg'] = torch.zeros_like(p)
  149. # Last Gradient
  150. state['last_grad'] = torch.zeros_like(p)
  151. # Exponential moving average of squared gradient values
  152. state['exp_avg_sq'] = torch.zeros_like(p)
  153. state['step'] += 1
  154. step = state['step']
  155. exp_avg = state['exp_avg']
  156. exp_avg_sq = state['exp_avg_sq']
  157. last_grad = state['last_grad']
  158. lr = group['lr']
  159. wd = group['weight_decay']
  160. beta1, beta2 = group['betas']
  161. is_grad_2d = grad.ndim >= 2
  162. # FIXME add multi-tensor (if usage warrants), make more standard
  163. _mars_single_tensor_step(
  164. p,
  165. grad,
  166. exp_avg,
  167. exp_avg_sq,
  168. lr,
  169. wd,
  170. beta1,
  171. beta2,
  172. last_grad,
  173. group['eps'],
  174. step,
  175. group['gamma'],
  176. mars_type=group['mars_type'],
  177. is_grad_2d=is_grad_2d,
  178. optimize_1d=group['optimize_1d'],
  179. lr_1d_factor=group['lr_1d_factor'],
  180. betas_1d=group['betas_1d'],
  181. caution=group['caution'],
  182. )
  183. state['last_grad'] = grad
  184. return loss