adafactor.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. """ Adafactor Optimizer
  2. Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
  3. Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers
  4. Original header/copyright below.
  5. """
  6. # Copyright (c) Facebook, Inc. and its affiliates.
  7. #
  8. # This source code is licensed under the MIT license found in the
  9. # LICENSE file in the root directory of this source tree.
  10. import math
  11. from typing import Optional, Tuple
  12. import torch
  13. from ._types import ParamsT
  14. class Adafactor(torch.optim.Optimizer):
  15. """Implements Adafactor algorithm.
  16. This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
  17. (see https://arxiv.org/abs/1804.04235)
  18. Note that this optimizer internally adjusts the learning rate depending on the
  19. *scale_parameter*, *relative_step* and *warmup_init* options.
  20. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
  21. `relative_step=False`.
  22. Ags:
  23. params: iterable of parameters to optimize or dicts defining parameter groups
  24. lr: external learning rate
  25. eps: regularization constants for square gradient and parameter scale respectively
  26. eps_scale: regularization constants for parameter scale respectively
  27. clip_threshold: threshold of root-mean-square of final gradient update
  28. decay_rate: coefficient used to compute running averages of square gradient
  29. beta1: coefficient used for computing running averages of gradient
  30. weight_decay: weight decay
  31. scale_parameter: if True, learning rate is scaled by root-mean-square of parameter
  32. warmup_init: time-dependent learning rate computation depends on whether warm-up initialization is being used
  33. """
  34. def __init__(
  35. self,
  36. params: ParamsT,
  37. lr: Optional[float] = None,
  38. eps: float = 1e-30,
  39. eps_scale: float = 1e-3,
  40. clip_threshold: float = 1.0,
  41. decay_rate: float = -0.8,
  42. betas: Optional[Tuple[float, float]] = None,
  43. weight_decay: float = 0.0,
  44. scale_parameter: bool = True,
  45. warmup_init: bool = False,
  46. min_dim_size_to_factor: int = 16,
  47. caution: bool = False,
  48. ):
  49. relative_step = not lr
  50. if warmup_init and not relative_step:
  51. raise ValueError('warmup_init requires relative_step=True')
  52. beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
  53. defaults = dict(
  54. lr=lr,
  55. eps=eps,
  56. eps_scale=eps_scale,
  57. clip_threshold=clip_threshold,
  58. decay_rate=decay_rate,
  59. beta1=beta1,
  60. weight_decay=weight_decay,
  61. scale_parameter=scale_parameter,
  62. relative_step=relative_step,
  63. warmup_init=warmup_init,
  64. min_dim_size_to_factor=min_dim_size_to_factor,
  65. caution=caution,
  66. )
  67. super(Adafactor, self).__init__(params, defaults)
  68. def __setstate__(self, state):
  69. super().__setstate__(state)
  70. for group in self.param_groups:
  71. group.setdefault('caution', False)
  72. group.setdefault('min_dim_size_to_factor', 16)
  73. @staticmethod
  74. def _get_lr(param_group, param_state):
  75. if param_group['relative_step']:
  76. min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
  77. lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
  78. param_scale = 1.0
  79. if param_group['scale_parameter']:
  80. param_scale = max(param_group['eps_scale'], param_state['RMS'])
  81. param_group['lr'] = lr_t * param_scale
  82. return param_group['lr']
  83. @staticmethod
  84. def _get_options(param_group, param_shape, min_size_to_factor=16):
  85. use_first_moment = param_group['beta1'] is not None
  86. factored = None
  87. ndim = len(param_shape)
  88. # Use a simple heuristic to pick factorization row & col, note other PyTorch impl tend to
  89. # always use -2, -1 BUT this will not pick correct dims for convolutions. This is a simple
  90. # approach that should work in most cases, compare to the slightly more involved approach
  91. # in AdafactorBigVision that sorts dims by size, please report if wrong dims chosen.
  92. if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
  93. # nD convs in torch are ND + 2 dim weights with leading in/out chs
  94. factored = 0, 1
  95. elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
  96. # if the criteria above didn't match, test trailing dims for eligibility as per original impl
  97. factored = ndim - 2, ndim - 1
  98. return factored, use_first_moment
  99. @staticmethod
  100. def _rms(tensor):
  101. return tensor.norm(2) / (tensor.numel() ** 0.5)
  102. def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row):
  103. # from our dim heuristic, always dim_col < dim_row, so col reduction dim for factored row = dim_col
  104. r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row)
  105. c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
  106. return torch.mul(r_factor, c_factor)
  107. @torch.no_grad()
  108. def step(self, closure=None):
  109. """Performs a single optimization step.
  110. Arguments:
  111. closure (callable, optional): A closure that reevaluates the model and returns the loss.
  112. """
  113. loss = None
  114. if closure is not None:
  115. with torch.enable_grad():
  116. loss = closure()
  117. for group in self.param_groups:
  118. for p in group['params']:
  119. if p.grad is None:
  120. continue
  121. grad = p.grad
  122. if grad.dtype in {torch.float16, torch.bfloat16}:
  123. grad = grad.float()
  124. if grad.is_sparse:
  125. raise RuntimeError('Adafactor does not support sparse gradients.')
  126. state = self.state[p]
  127. factored_dims, use_first_moment = self._get_options(
  128. group,
  129. grad.shape,
  130. min_size_to_factor=group['min_dim_size_to_factor'],
  131. )
  132. # State Initialization
  133. if len(state) == 0:
  134. state['step'] = 0
  135. if use_first_moment:
  136. # Exponential moving average of gradient values
  137. state['exp_avg'] = torch.zeros_like(grad)
  138. if factored_dims is not None:
  139. dim_col, dim_row = factored_dims
  140. def _remove_dim(shape, dim):
  141. return shape[:dim] + shape[dim + 1:]
  142. state['exp_avg_sq_row'] = torch.zeros(_remove_dim(grad.shape, dim_row)).to(grad)
  143. state['exp_avg_sq_col'] = torch.zeros(_remove_dim(grad.shape, dim_col)).to(grad)
  144. else:
  145. state['exp_avg_sq'] = torch.zeros_like(grad)
  146. state['RMS'] = 0
  147. else:
  148. if use_first_moment:
  149. state['exp_avg'] = state['exp_avg'].to(grad)
  150. if factored_dims is not None:
  151. state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
  152. state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
  153. else:
  154. state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
  155. p_fp32 = p
  156. if p.dtype in {torch.float16, torch.bfloat16}:
  157. p_fp32 = p_fp32.float()
  158. state['step'] += 1
  159. state['RMS'] = self._rms(p_fp32)
  160. lr_t = self._get_lr(group, state)
  161. beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
  162. update = grad ** 2 + group['eps']
  163. if factored_dims is not None:
  164. dim_col, dim_row = factored_dims
  165. exp_avg_sq_row = state['exp_avg_sq_row']
  166. exp_avg_sq_col = state['exp_avg_sq_col']
  167. exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t)
  168. exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t)
  169. # Approximation of exponential moving average of square of gradient
  170. update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row)
  171. update.mul_(grad)
  172. else:
  173. exp_avg_sq = state['exp_avg_sq']
  174. exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
  175. update = exp_avg_sq.rsqrt().mul_(grad)
  176. update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
  177. update.mul_(lr_t)
  178. if use_first_moment:
  179. exp_avg = state['exp_avg']
  180. exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
  181. if group['caution']:
  182. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  183. mask = (exp_avg * grad > 0).to(grad.dtype)
  184. mask.div_(mask.mean().clamp_(min=1e-3))
  185. update = exp_avg * mask
  186. else:
  187. update = exp_avg
  188. if group['weight_decay'] != 0:
  189. p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
  190. p_fp32.add_(-update)
  191. if p.dtype in {torch.float16, torch.bfloat16}:
  192. p.copy_(p_fp32)
  193. return loss