adafactor_bv.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. """ Adafactor (Big Vision variant) for PyTorch
  2. Adapted from the implementation in big vision: https://github.com/google-research/big_vision
  3. Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
  4. References for added functionality:
  5. Cautious Optimizers: https://arxiv.org/abs/2411.16085
  6. Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
  7. Adaptation and PyTorch modifications by Ross Wightman
  8. """
  9. from typing import List, Optional, Tuple, Union
  10. import torch
  11. from torch import Tensor
  12. from torch.optim import Optimizer
  13. from ._types import ParamsT
  14. def _get_scalar_dtype():
  15. """Get the scalar dtype that the optimizer uses for state"""
  16. return torch.float64
  17. def _factored_dims(
  18. shape: Tuple[int, ...],
  19. factored: bool,
  20. min_dim_size_to_factor: int
  21. ) -> Optional[tuple[int, int]]:
  22. """Whether to use a factored second moment estimator.
  23. This function returns a tuple with the two largest axes to reduce over.
  24. If no two dimensions have size >= min_dim_size_to_factor, return None.
  25. Args:
  26. shape: an input shape
  27. factored: whether to use factored second-moment estimator for > 2d vars.
  28. min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
  29. Returns:
  30. None or a tuple of ints
  31. """
  32. if not factored or len(shape) < 2:
  33. return None
  34. sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
  35. if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
  36. return None
  37. return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
  38. class AdafactorBigVision(Optimizer):
  39. """
  40. PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
  41. Adapted from https://github.com/google-research/big_vision by Ross Wightman
  42. """
  43. def __init__(
  44. self,
  45. params: ParamsT,
  46. lr: float = 1.0,
  47. min_dim_size_to_factor: int = 16,
  48. decay_rate: float = 0.8,
  49. decay_offset: int = 0,
  50. beta2_cap: float = 0.999,
  51. momentum: Optional[float] = 0.9,
  52. momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
  53. eps: Optional[float] = None,
  54. weight_decay: float = 0.0,
  55. clipping_threshold: Optional[float] = None,
  56. unscaled_wd: bool = False,
  57. caution: bool = False,
  58. corrected_weight_decay: bool = False,
  59. *,
  60. foreach: Optional[bool] = False,
  61. ):
  62. if isinstance(momentum_dtype, str):
  63. if momentum_dtype == 'float16':
  64. momentum_dtype = torch.float16
  65. elif momentum_dtype == 'bfloat16':
  66. momentum_dtype = torch.bfloat16
  67. else:
  68. assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
  69. momentum_dtype = torch.float32
  70. # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
  71. defaults = dict(
  72. lr=lr,
  73. min_dim_size_to_factor=min_dim_size_to_factor,
  74. decay_rate=decay_rate,
  75. decay_offset=decay_offset,
  76. beta2_cap=beta2_cap,
  77. momentum=momentum,
  78. momentum_dtype=momentum_dtype,
  79. eps=eps,
  80. weight_decay=weight_decay,
  81. clipping_threshold=clipping_threshold,
  82. unscaled_wd=unscaled_wd,
  83. caution=caution,
  84. corrected_weight_decay=corrected_weight_decay,
  85. foreach=foreach,
  86. )
  87. super().__init__(params, defaults)
  88. def __setstate__(self, state):
  89. super().__setstate__(state)
  90. for group in self.param_groups:
  91. group.setdefault('caution', False)
  92. group.setdefault('corrected_weight_decay', False)
  93. group.setdefault('foreach', None)
  94. for p in group['params']:
  95. p_state = self.state.get(p, {})
  96. if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
  97. p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
  98. if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
  99. # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
  100. # the momentum to float32 (it's half precision in the state_dict), need to
  101. # look into this further. Better to override _process_value_according_to_param_policy?
  102. p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
  103. @torch.no_grad()
  104. def step(self, closure=None):
  105. loss = None
  106. if closure is not None:
  107. with torch.enable_grad():
  108. loss = closure()
  109. for group in self.param_groups:
  110. params_with_grad = []
  111. grads = []
  112. exp_avg_sq_rs = []
  113. exp_avg_sq_cs = []
  114. exp_avg_sqs = []
  115. state_steps = []
  116. exp_avgs = [] # For momentum
  117. for p in group['params']:
  118. if p.grad is None:
  119. continue
  120. if p.grad.is_sparse:
  121. raise RuntimeError("Sparse gradients not supported")
  122. params_with_grad.append(p)
  123. grads.append(p.grad)
  124. state = self.state[p]
  125. if len(state) == 0:
  126. # NOTE step on CPU, probably need some more though to make capturable
  127. state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())
  128. shape = p.grad.shape
  129. factored_dims = _factored_dims(
  130. shape,
  131. factored=True,
  132. min_dim_size_to_factor=self.defaults['min_dim_size_to_factor']
  133. )
  134. if factored_dims is not None:
  135. dc, dr = factored_dims
  136. row_shape = list(p.grad.shape)
  137. row_shape[dr] = 1
  138. col_shape = list(p.grad.shape)
  139. col_shape[dc] = 1
  140. state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
  141. state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
  142. else:
  143. state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
  144. if self.defaults['momentum'] is not None:
  145. state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])
  146. state_steps.append(state['step'])
  147. exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
  148. exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
  149. exp_avg_sqs.append(state.get('exp_avg_sq', None))
  150. exp_avgs.append(state.get('exp_avg', None))
  151. if group['foreach']:
  152. func = _multi_tensor_adafactor
  153. else:
  154. func = _single_tensor_adafactor
  155. func(
  156. params=params_with_grad,
  157. grads=grads,
  158. exp_avg_sq_rs=exp_avg_sq_rs,
  159. exp_avg_sq_cs=exp_avg_sq_cs,
  160. exp_avg_sqs=exp_avg_sqs,
  161. exp_avgs=exp_avgs,
  162. state_steps=state_steps,
  163. beta2_decay=group['decay_rate'],
  164. beta2_cap=group['beta2_cap'],
  165. min_dim_size_to_factor=group['min_dim_size_to_factor'],
  166. eps=group['eps'],
  167. lr=group['lr'],
  168. weight_decay=group['weight_decay'],
  169. momentum=group['momentum'],
  170. momentum_dtype=group['momentum_dtype'],
  171. clipping_threshold=group['clipping_threshold'],
  172. unscaled_wd=group['unscaled_wd'],
  173. caution=group['caution'],
  174. max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
  175. )
  176. return loss
  177. def _single_tensor_adafactor(
  178. params: List[Tensor],
  179. grads: List[Tensor],
  180. exp_avg_sq_rs: List[Optional[Tensor]],
  181. exp_avg_sq_cs: List[Optional[Tensor]],
  182. exp_avg_sqs: List[Optional[Tensor]],
  183. exp_avgs: List[Optional[Tensor]],
  184. state_steps: List[Tensor],
  185. *,
  186. beta2_decay: float,
  187. beta2_cap: float,
  188. min_dim_size_to_factor: int,
  189. eps: float,
  190. lr: float,
  191. weight_decay: float,
  192. momentum: Optional[float],
  193. momentum_dtype: Union[str, torch.dtype],
  194. clipping_threshold: Optional[float],
  195. unscaled_wd: bool,
  196. caution: bool,
  197. max_lr: Optional[float],
  198. ):
  199. for i, param in enumerate(params):
  200. grad = grads[i]
  201. exp_avg_sq_r = exp_avg_sq_rs[i]
  202. exp_avg_sq_c = exp_avg_sq_cs[i]
  203. exp_avg_sq = exp_avg_sqs[i]
  204. exp_avg = exp_avgs[i]
  205. step_t = state_steps[i]
  206. if eps is None:
  207. # default eps for avoiding div by zero, diff from float type eps
  208. eps = 1e-7 if grad.dtype == torch.float16 else 1e-30
  209. # Update step
  210. step_t += 1
  211. beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
  212. one_minus_beta2_t = 1 - beta2_t
  213. grad_sqr = torch.square(grad) + eps
  214. # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
  215. if exp_avg_sq is None:
  216. # factorized second moment
  217. dc, dr = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
  218. exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t)
  219. exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t)
  220. reduce_dc = dc - 1 if dc > dr else dc
  221. row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True)
  222. row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
  223. col_factor = exp_avg_sq_c.rsqrt()
  224. update = grad * row_factor * col_factor
  225. else:
  226. # non-factorized second moment
  227. assert exp_avg_sq_r is None and exp_avg_sq_c is None
  228. exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t)
  229. update = grad * exp_avg_sq.rsqrt()
  230. # Clip by RMS value
  231. if clipping_threshold is not None:
  232. denom = (update.norm(2) / ((update.numel() ** 0.5) / clipping_threshold)).clamp_(max=1.0)
  233. update.div_(denom)
  234. # Apply momentum (in different dtype)
  235. if momentum is not None and exp_avg is not None:
  236. if momentum_dtype != grad.dtype:
  237. exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum) # ema
  238. update = exp_avg.to(grad.dtype)
  239. else:
  240. exp_avg.lerp_(update, 1 - momentum) # ema
  241. update = exp_avg.clone()
  242. if caution:
  243. # apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085
  244. mask = (update * grad > 0).to(grad.dtype)
  245. mask.div_(mask.mean().clamp_(min=1e-3))
  246. update.mul_(mask)
  247. # Scale by learning rate
  248. update.mul_(lr)
  249. # Perform weight decay
  250. if weight_decay != 0:
  251. if unscaled_wd:
  252. # match big vision impl, 'fully decoupled' decay w/o LR scaling
  253. if max_lr is None:
  254. param.mul_(1. - weight_decay)
  255. else:
  256. # corrected weight decay: scale by lr / max_lr
  257. param.mul_(1. - (lr / max_lr) * weight_decay)
  258. else:
  259. # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
  260. if max_lr is None:
  261. param.mul_(1. - lr * weight_decay)
  262. else:
  263. # corrected weight decay: scale by lr^2 / max_lr
  264. param.mul_(1. - (lr ** 2 / max_lr) * weight_decay)
  265. # Update parameters
  266. param.add_(update, alpha=-1.0)
  267. def _multi_tensor_adafactor(
  268. params: List[Tensor],
  269. grads: List[Tensor],
  270. exp_avg_sq_rs: List[Optional[Tensor]],
  271. exp_avg_sq_cs: List[Optional[Tensor]],
  272. exp_avg_sqs: List[Optional[Tensor]],
  273. exp_avgs: List[Optional[Tensor]],
  274. state_steps: List[Tensor],
  275. *,
  276. beta2_decay: float,
  277. beta2_cap: float,
  278. min_dim_size_to_factor: int,
  279. eps: float,
  280. lr: float,
  281. weight_decay: float,
  282. momentum: Optional[float],
  283. momentum_dtype: Union[str, torch.dtype],
  284. clipping_threshold: Optional[float],
  285. unscaled_wd: bool,
  286. caution: bool,
  287. max_lr: Optional[float],
  288. ):
  289. # FIXME TODO
  290. assert False, 'multi-tensor fn (foreach=True) not implemented yet'