lion.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """ Lion Optimizer
  2. Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675
  3. Original Impl: https://github.com/google/automl/tree/master/lion
  4. References for added functionality:
  5. Cautious Optimizers: https://arxiv.org/abs/2411.16085
  6. Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
  7. """
  8. # Copyright 2023 Google Research. All Rights Reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. # ==============================================================================
  22. from typing import List, Optional, Tuple
  23. import torch
  24. from torch.optim.optimizer import Optimizer
  25. from ._types import ParamsT
  26. class Lion(Optimizer):
  27. r"""Implements Lion algorithm."""
  28. def __init__(
  29. self,
  30. params: ParamsT,
  31. lr: float = 1e-4,
  32. betas: Tuple[float, float] = (0.9, 0.99),
  33. weight_decay: float = 0.0,
  34. caution: bool = False,
  35. corrected_weight_decay: bool = False,
  36. maximize: bool = False,
  37. foreach: Optional[bool] = None,
  38. ):
  39. """Initialize the hyperparameters.
  40. Args:
  41. params: iterable of parameters to optimize or dicts defining parameter groups
  42. lr: learning rate
  43. betas: coefficients used for computing running averages of gradient and its square
  44. weight_decay: weight decay coefficient
  45. caution: apply caution
  46. corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr)
  47. """
  48. if not 0.0 <= lr:
  49. raise ValueError('Invalid learning rate: {}'.format(lr))
  50. if not 0.0 <= betas[0] < 1.0:
  51. raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
  52. if not 0.0 <= betas[1] < 1.0:
  53. raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
  54. defaults = dict(
  55. lr=lr,
  56. betas=betas,
  57. weight_decay=weight_decay,
  58. caution=caution,
  59. corrected_weight_decay=corrected_weight_decay,
  60. foreach=foreach,
  61. maximize=maximize,
  62. )
  63. super().__init__(params, defaults)
  64. def __setstate__(self, state):
  65. super().__setstate__(state)
  66. for group in self.param_groups:
  67. group.setdefault('caution', False)
  68. group.setdefault('corrected_weight_decay', False)
  69. group.setdefault('maximize', False)
  70. group.setdefault('foreach', None)
  71. @torch.no_grad()
  72. def step(self, closure=None):
  73. """Performs a single optimization step.
  74. Args:
  75. closure: A closure that reevaluates the model and returns the loss.
  76. Returns:
  77. the loss.
  78. """
  79. loss = None
  80. if closure is not None:
  81. with torch.enable_grad():
  82. loss = closure()
  83. for group in self.param_groups:
  84. params_with_grad = []
  85. grads = []
  86. exp_avgs = []
  87. beta1, beta2 = group['betas']
  88. for p in group['params']:
  89. if p.grad is None:
  90. continue
  91. params_with_grad.append(p)
  92. if p.grad.is_sparse:
  93. raise RuntimeError('Lion does not support sparse gradients')
  94. grads.append(p.grad)
  95. state = self.state[p]
  96. # State initialization
  97. if len(state) == 0:
  98. state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  99. exp_avgs.append(state['exp_avg'])
  100. lion(
  101. params_with_grad,
  102. grads,
  103. exp_avgs,
  104. beta1=beta1,
  105. beta2=beta2,
  106. lr=group['lr'],
  107. weight_decay=group['weight_decay'],
  108. caution=group['caution'],
  109. maximize=group['maximize'],
  110. foreach=group['foreach'],
  111. max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
  112. )
  113. return loss
  114. def lion(
  115. params: List[torch.Tensor],
  116. grads: List[torch.Tensor],
  117. exp_avgs: List[torch.Tensor],
  118. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  119. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  120. maximize: bool = False,
  121. foreach: bool = None,
  122. *,
  123. beta1: float,
  124. beta2: float,
  125. lr: float,
  126. weight_decay: float,
  127. caution: bool,
  128. max_lr: Optional[float] = None,
  129. ):
  130. r"""Functional API that performs Lion algorithm computation.
  131. """
  132. if foreach is None:
  133. try:
  134. # cannot do foreach if this overload doesn't exist when caution enabled
  135. foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
  136. except Exception:
  137. foreach = False
  138. if foreach and torch.jit.is_scripting():
  139. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  140. if foreach and not torch.jit.is_scripting():
  141. func = _multi_tensor_lion
  142. else:
  143. func = _single_tensor_lion
  144. func(
  145. params,
  146. grads,
  147. exp_avgs,
  148. beta1=beta1,
  149. beta2=beta2,
  150. lr=lr,
  151. weight_decay=weight_decay,
  152. caution=caution,
  153. maximize=maximize,
  154. max_lr=max_lr,
  155. )
  156. def _single_tensor_lion(
  157. params: List[torch.Tensor],
  158. grads: List[torch.Tensor],
  159. exp_avgs: List[torch.Tensor],
  160. *,
  161. beta1: float,
  162. beta2: float,
  163. lr: float,
  164. weight_decay: float,
  165. caution: bool,
  166. maximize: bool,
  167. max_lr: Optional[float],
  168. ):
  169. for i, param in enumerate(params):
  170. grad = grads[i] if not maximize else -grads[i]
  171. exp_avg = exp_avgs[i]
  172. if torch.is_complex(param):
  173. grad = torch.view_as_real(grad)
  174. exp_avg = torch.view_as_real(exp_avg)
  175. param = torch.view_as_real(param)
  176. # Perform stepweight decay
  177. wd_scale = lr if max_lr is None else lr ** 2 / max_lr
  178. param.mul_(1 - wd_scale * weight_decay)
  179. # Weight update
  180. update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
  181. if caution:
  182. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  183. mask = (update * grad > 0).to(grad.dtype)
  184. mask.div_(mask.mean().clamp_(min=1e-3))
  185. update.mul_(mask)
  186. param.add_(update, alpha=-lr)
  187. # Decay the momentum running average coefficient
  188. exp_avg.lerp_(grad, 1 - beta2)
  189. def _multi_tensor_lion(
  190. params: List[torch.Tensor],
  191. grads: List[torch.Tensor],
  192. exp_avgs: List[torch.Tensor],
  193. *,
  194. beta1: float,
  195. beta2: float,
  196. lr: float,
  197. weight_decay: float,
  198. caution: bool,
  199. maximize: bool,
  200. max_lr: Optional[float],
  201. ):
  202. if len(params) == 0:
  203. return
  204. if maximize:
  205. grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]
  206. grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
  207. exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
  208. params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
  209. # Perform stepweight decay
  210. wd_scale = lr if max_lr is None else lr ** 2 / max_lr
  211. torch._foreach_mul_(params, 1 - wd_scale * weight_decay)
  212. # Weight update
  213. updates = torch._foreach_mul(exp_avgs, beta1)
  214. torch._foreach_add_(updates, grads, alpha=1 - beta1)
  215. updates = [u.sign_() for u in updates]
  216. if caution:
  217. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  218. masks = torch._foreach_mul(updates, grads)
  219. masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
  220. mask_scale = [m.mean() for m in masks]
  221. torch._foreach_maximum_(mask_scale, 1e-3)
  222. torch._foreach_div_(masks, mask_scale)
  223. torch._foreach_mul_(updates, masks)
  224. torch._foreach_add_(params, updates, alpha=-lr)
  225. # Decay the momentum running average coefficient
  226. torch._foreach_mul_(exp_avgs, beta2)
  227. torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)