lamb.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
  2. This optimizer code was adapted from the following (starting with latest)
  3. * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
  4. * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
  5. * https://github.com/cybertronai/pytorch-lamb
  6. Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
  7. similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
  8. In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
  9. References for added functionality:
  10. Cautious Optimizers: https://arxiv.org/abs/2411.16085
  11. Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
  12. Original copyrights for above sources are below.
  13. Modifications Copyright 2021 Ross Wightman
  14. """
  15. # Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
  16. # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
  17. #
  18. # Licensed under the Apache License, Version 2.0 (the "License");
  19. # you may not use this file except in compliance with the License.
  20. # You may obtain a copy of the License at
  21. #
  22. # http://www.apache.org/licenses/LICENSE-2.0
  23. #
  24. # Unless required by applicable law or agreed to in writing, software
  25. # distributed under the License is distributed on an "AS IS" BASIS,
  26. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  27. # See the License for the specific language governing permissions and
  28. # limitations under the License.
  29. # MIT License
  30. #
  31. # Copyright (c) 2019 cybertronai
  32. #
  33. # Permission is hereby granted, free of charge, to any person obtaining a copy
  34. # of this software and associated documentation files (the "Software"), to deal
  35. # in the Software without restriction, including without limitation the rights
  36. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  37. # copies of the Software, and to permit persons to whom the Software is
  38. # furnished to do so, subject to the following conditions:
  39. #
  40. # The above copyright notice and this permission notice shall be included in all
  41. # copies or substantial portions of the Software.
  42. #
  43. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  44. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  45. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  46. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  47. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  48. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  49. # SOFTWARE.
  50. import math
  51. from typing import Optional, Tuple
  52. import torch
  53. from torch.optim import Optimizer
  54. from ._types import ParamsT
  55. class Lamb(Optimizer):
  56. """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
  57. reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
  58. LAMB was proposed in:
  59. - Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
  60. - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
  61. Args:
  62. params: Iterable of parameters to optimize or dicts defining parameter groups.
  63. lr: Learning rate
  64. betas: Coefficients used for computing running averages of gradient and its norm.
  65. eps: Term added to the denominator to improve numerical stability.
  66. weight_decay: Weight decay
  67. grad_averaging: Whether apply (1-beta2) to grad when calculating running averages of gradient.
  68. max_grad_norm: Value used to clip global grad norm.
  69. trust_clip: Enable LAMBC trust ratio clipping.
  70. always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter.
  71. caution: Apply caution.
  72. decoupled: apply decoupled weight decay
  73. corrected_weight_decay: apply corrected weight decay (lr**2 / max_lr) when using decoupled_decay
  74. """
  75. def __init__(
  76. self,
  77. params: ParamsT,
  78. lr: float = 1e-3,
  79. bias_correction: bool = True,
  80. betas: Tuple[float, float] = (0.9, 0.999),
  81. eps: float = 1e-6,
  82. weight_decay: float = 0.01,
  83. grad_averaging: bool = True,
  84. max_grad_norm: Optional[float] = 1.0,
  85. trust_clip: bool = False,
  86. always_adapt: bool = False,
  87. caution: bool = False,
  88. decoupled_decay: bool = False,
  89. corrected_weight_decay: bool = False,
  90. ):
  91. defaults = dict(
  92. lr=lr,
  93. bias_correction=bias_correction,
  94. betas=betas,
  95. eps=eps,
  96. weight_decay=weight_decay,
  97. grad_averaging=grad_averaging,
  98. max_grad_norm=max_grad_norm,
  99. trust_clip=trust_clip,
  100. always_adapt=always_adapt,
  101. caution=caution,
  102. decoupled_decay=decoupled_decay,
  103. corrected_weight_decay=corrected_weight_decay,
  104. )
  105. super().__init__(params, defaults)
  106. def __setstate__(self, state):
  107. super().__setstate__(state)
  108. for group in self.param_groups:
  109. group.setdefault('caution', False)
  110. group.setdefault('decoupled_decay', False)
  111. group.setdefault('corrected_weight_decay', False)
  112. def _get_clip_grad_norm(self):
  113. max_grad_norm = self.defaults['max_grad_norm']
  114. if max_grad_norm is None:
  115. return None
  116. norms = []
  117. for group in self.param_groups:
  118. for p in group['params']:
  119. if p.grad is None:
  120. continue
  121. grad = p.grad
  122. if grad.is_sparse:
  123. raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
  124. norms.append(torch.linalg.vector_norm(grad))
  125. global_norm = torch.linalg.vector_norm(torch.stack(norms))
  126. clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0)
  127. return clip_global_norm
  128. @torch.no_grad()
  129. def step(self, closure=None):
  130. """Performs a single optimization step.
  131. Arguments:
  132. closure (callable, optional): A closure that reevaluates the model
  133. and returns the loss.
  134. """
  135. loss = None
  136. if closure is not None:
  137. with torch.enable_grad():
  138. loss = closure()
  139. clip_grad_norm = self._get_clip_grad_norm() # None if disabled
  140. for group in self.param_groups:
  141. bias_correction = 1 if group['bias_correction'] else 0
  142. beta1, beta2 = group['betas']
  143. grad_averaging = 1 if group['grad_averaging'] else 0
  144. beta3 = 1 - beta1 if grad_averaging else 1.0
  145. # assume same step across group now to simplify things
  146. # per parameter step can be easily support by making it tensor, or pass list into kernel
  147. if 'step' in group:
  148. group['step'] += 1
  149. else:
  150. group['step'] = 1
  151. if bias_correction:
  152. bias_correction1 = 1 - beta1 ** group['step']
  153. bias_correction2 = 1 - beta2 ** group['step']
  154. else:
  155. bias_correction1, bias_correction2 = 1.0, 1.0
  156. for p in group['params']:
  157. if p.grad is None:
  158. continue
  159. grad = p.grad
  160. if clip_grad_norm is not None:
  161. grad.div_(clip_grad_norm)
  162. state = self.state[p]
  163. # State initialization
  164. if len(state) == 0:
  165. # Exponential moving average of gradient valuesa
  166. state['exp_avg'] = torch.zeros_like(p)
  167. # Exponential moving average of squared gradient values
  168. state['exp_avg_sq'] = torch.zeros_like(p)
  169. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  170. # Decay the first and second moment running average coefficient
  171. exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
  172. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
  173. denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
  174. update = (exp_avg / bias_correction1).div_(denom)
  175. if group['caution']:
  176. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  177. mask = (update * grad > 0).to(grad.dtype)
  178. mask.div_(mask.mean().clamp_(min=1e-3))
  179. update.mul_(mask)
  180. weight_decay = group['weight_decay']
  181. if weight_decay != 0:
  182. if group.get('decoupled_decay', False):
  183. if group['corrected_weight_decay']:
  184. wd_scale = group['lr'] ** 2 / self.defaults['lr']
  185. else:
  186. wd_scale = group['lr']
  187. p.add_(p, alpha=-wd_scale * weight_decay)
  188. else:
  189. update.add_(p, alpha=weight_decay)
  190. if weight_decay != 0 or group['always_adapt']:
  191. # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
  192. # excluded from weight decay, unless always_adapt == True, then always enabled.
  193. w_norm = p.norm(2.0)
  194. g_norm = update.norm(2.0)
  195. trust_ratio = w_norm / g_norm
  196. # FIXME nested where required since logical and/or not working in PT XLA
  197. # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
  198. trust_ratio = torch.where(
  199. w_norm > 0,
  200. torch.where(g_norm > 0, trust_ratio, 1.0),
  201. 1.0,
  202. )
  203. if group['trust_clip']:
  204. # LAMBC trust clipping, upper bound fixed at one
  205. trust_ratio = torch.clamp(trust_ratio, max=1.0)
  206. update.mul_(trust_ratio)
  207. p.add_(update, alpha=-group['lr'])
  208. return loss