adan.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. """ Adan Optimizer
  2. Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
  3. https://arxiv.org/abs/2208.06677
  4. Implementation adapted from https://github.com/sail-sg/Adan
  5. """
  6. # Copyright 2022 Garena Online Private Limited
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. import math
  20. from typing import List, Optional, Tuple
  21. import torch
  22. from torch import Tensor
  23. from torch.optim.optimizer import Optimizer
  24. class MultiTensorApply(object):
  25. available = False
  26. warned = False
  27. def __init__(self, chunk_size):
  28. try:
  29. MultiTensorApply.available = True
  30. self.chunk_size = chunk_size
  31. except ImportError as err:
  32. MultiTensorApply.available = False
  33. MultiTensorApply.import_err = err
  34. def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
  35. return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
  36. class Adan(Optimizer):
  37. """ Implements a pytorch variant of Adan.
  38. Adan was proposed in Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models
  39. https://arxiv.org/abs/2208.06677
  40. Arguments:
  41. params: Iterable of parameters to optimize or dicts defining parameter groups.
  42. lr: Learning rate.
  43. betas: Coefficients used for first- and second-order moments.
  44. eps: Term added to the denominator to improve numerical stability.
  45. weight_decay: Decoupled weight decay (L2 penalty)
  46. no_prox: How to perform the weight decay
  47. caution: Enable caution from 'Cautious Optimizers'
  48. foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
  49. """
  50. def __init__(self,
  51. params,
  52. lr: float = 1e-3,
  53. betas: Tuple[float, float, float] = (0.98, 0.92, 0.99),
  54. eps: float = 1e-8,
  55. weight_decay: float = 0.0,
  56. no_prox: bool = False,
  57. caution: bool = False,
  58. foreach: Optional[bool] = None,
  59. ):
  60. if not 0.0 <= lr:
  61. raise ValueError('Invalid learning rate: {}'.format(lr))
  62. if not 0.0 <= eps:
  63. raise ValueError('Invalid epsilon value: {}'.format(eps))
  64. if not 0.0 <= betas[0] < 1.0:
  65. raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
  66. if not 0.0 <= betas[1] < 1.0:
  67. raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
  68. if not 0.0 <= betas[2] < 1.0:
  69. raise ValueError('Invalid beta parameter at index 2: {}'.format(betas[2]))
  70. defaults = dict(
  71. lr=lr,
  72. betas=betas,
  73. eps=eps,
  74. weight_decay=weight_decay,
  75. no_prox=no_prox,
  76. caution=caution,
  77. foreach=foreach,
  78. )
  79. super().__init__(params, defaults)
  80. def __setstate__(self, state):
  81. super(Adan, self).__setstate__(state)
  82. for group in self.param_groups:
  83. group.setdefault('no_prox', False)
  84. group.setdefault('caution', False)
  85. @torch.no_grad()
  86. def restart_opt(self):
  87. for group in self.param_groups:
  88. group['step'] = 0
  89. for p in group['params']:
  90. if p.requires_grad:
  91. state = self.state[p]
  92. # State initialization
  93. # Exponential moving average of gradient values
  94. state['exp_avg'] = torch.zeros_like(p)
  95. # Exponential moving average of squared gradient values
  96. state['exp_avg_sq'] = torch.zeros_like(p)
  97. # Exponential moving average of gradient difference
  98. state['exp_avg_diff'] = torch.zeros_like(p)
  99. @torch.no_grad()
  100. def step(self, closure=None):
  101. """Performs a single optimization step."""
  102. loss = None
  103. if closure is not None:
  104. with torch.enable_grad():
  105. loss = closure()
  106. try:
  107. has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
  108. except Exception:
  109. has_scalar_maximum = False
  110. for group in self.param_groups:
  111. params_with_grad = []
  112. grads = []
  113. exp_avgs = []
  114. exp_avg_sqs = []
  115. exp_avg_diffs = []
  116. neg_pre_grads = []
  117. beta1, beta2, beta3 = group['betas']
  118. # assume same step across group now to simplify things
  119. # per parameter step can be easily supported by making it a tensor, or pass list into kernel
  120. if 'step' in group:
  121. group['step'] += 1
  122. else:
  123. group['step'] = 1
  124. bias_correction1 = 1.0 - beta1 ** group['step']
  125. bias_correction2 = 1.0 - beta2 ** group['step']
  126. bias_correction3 = 1.0 - beta3 ** group['step']
  127. for p in group['params']:
  128. if p.grad is None:
  129. continue
  130. params_with_grad.append(p)
  131. grads.append(p.grad)
  132. state = self.state[p]
  133. if len(state) == 0:
  134. state['exp_avg'] = torch.zeros_like(p)
  135. state['exp_avg_sq'] = torch.zeros_like(p)
  136. state['exp_avg_diff'] = torch.zeros_like(p)
  137. if 'neg_pre_grad' not in state or group['step'] == 1:
  138. state['neg_pre_grad'] = -p.grad.clone()
  139. exp_avgs.append(state['exp_avg'])
  140. exp_avg_sqs.append(state['exp_avg_sq'])
  141. exp_avg_diffs.append(state['exp_avg_diff'])
  142. neg_pre_grads.append(state['neg_pre_grad'])
  143. if not params_with_grad:
  144. continue
  145. if group['foreach'] is None:
  146. use_foreach = not group['caution'] or has_scalar_maximum
  147. else:
  148. use_foreach = group['foreach']
  149. if use_foreach:
  150. func = _multi_tensor_adan
  151. else:
  152. func = _single_tensor_adan
  153. func(
  154. params_with_grad,
  155. grads,
  156. exp_avgs=exp_avgs,
  157. exp_avg_sqs=exp_avg_sqs,
  158. exp_avg_diffs=exp_avg_diffs,
  159. neg_pre_grads=neg_pre_grads,
  160. beta1=beta1,
  161. beta2=beta2,
  162. beta3=beta3,
  163. bias_correction1=bias_correction1,
  164. bias_correction2=bias_correction2,
  165. bias_correction3_sqrt=math.sqrt(bias_correction3),
  166. lr=group['lr'],
  167. weight_decay=group['weight_decay'],
  168. eps=group['eps'],
  169. no_prox=group['no_prox'],
  170. caution=group['caution'],
  171. )
  172. return loss
  173. def _single_tensor_adan(
  174. params: List[Tensor],
  175. grads: List[Tensor],
  176. exp_avgs: List[Tensor],
  177. exp_avg_sqs: List[Tensor],
  178. exp_avg_diffs: List[Tensor],
  179. neg_pre_grads: List[Tensor],
  180. *,
  181. beta1: float,
  182. beta2: float,
  183. beta3: float,
  184. bias_correction1: float,
  185. bias_correction2: float,
  186. bias_correction3_sqrt: float,
  187. lr: float,
  188. weight_decay: float,
  189. eps: float,
  190. no_prox: bool,
  191. caution: bool,
  192. ):
  193. for i, param in enumerate(params):
  194. grad = grads[i]
  195. exp_avg = exp_avgs[i]
  196. exp_avg_sq = exp_avg_sqs[i]
  197. exp_avg_diff = exp_avg_diffs[i]
  198. neg_grad_or_diff = neg_pre_grads[i]
  199. # for memory saving, we use `neg_grad_or_diff` to get some temp variable in an inplace way
  200. neg_grad_or_diff.add_(grad)
  201. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t
  202. exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, alpha=1 - beta2) # diff_t
  203. neg_grad_or_diff.mul_(beta2).add_(grad)
  204. exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, neg_grad_or_diff, value=1 - beta3) # n_t
  205. denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
  206. step_size_diff = lr * beta2 / bias_correction2
  207. step_size = lr / bias_correction1
  208. if caution:
  209. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  210. mask = (exp_avg * grad > 0).to(grad.dtype)
  211. mask.div_(mask.mean().clamp_(min=1e-3))
  212. exp_avg = exp_avg * mask
  213. if no_prox:
  214. param.mul_(1 - lr * weight_decay)
  215. param.addcdiv_(exp_avg, denom, value=-step_size)
  216. param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
  217. else:
  218. param.addcdiv_(exp_avg, denom, value=-step_size)
  219. param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
  220. param.div_(1 + lr * weight_decay)
  221. neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
  222. def _multi_tensor_adan(
  223. params: List[Tensor],
  224. grads: List[Tensor],
  225. exp_avgs: List[Tensor],
  226. exp_avg_sqs: List[Tensor],
  227. exp_avg_diffs: List[Tensor],
  228. neg_pre_grads: List[Tensor],
  229. *,
  230. beta1: float,
  231. beta2: float,
  232. beta3: float,
  233. bias_correction1: float,
  234. bias_correction2: float,
  235. bias_correction3_sqrt: float,
  236. lr: float,
  237. weight_decay: float,
  238. eps: float,
  239. no_prox: bool,
  240. caution: bool,
  241. ):
  242. if len(params) == 0:
  243. return
  244. # for memory saving, we use `neg_pre_grads` to get some temp variable in a inplace way
  245. torch._foreach_add_(neg_pre_grads, grads)
  246. torch._foreach_mul_(exp_avgs, beta1)
  247. torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t
  248. torch._foreach_mul_(exp_avg_diffs, beta2)
  249. torch._foreach_add_(exp_avg_diffs, neg_pre_grads, alpha=1 - beta2) # diff_t
  250. torch._foreach_mul_(neg_pre_grads, beta2)
  251. torch._foreach_add_(neg_pre_grads, grads)
  252. torch._foreach_mul_(exp_avg_sqs, beta3)
  253. torch._foreach_addcmul_(exp_avg_sqs, neg_pre_grads, neg_pre_grads, value=1 - beta3) # n_t
  254. denom = torch._foreach_sqrt(exp_avg_sqs)
  255. torch._foreach_div_(denom, bias_correction3_sqrt)
  256. torch._foreach_add_(denom, eps)
  257. step_size_diff = lr * beta2 / bias_correction2
  258. step_size = lr / bias_correction1
  259. if caution:
  260. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  261. masks = torch._foreach_mul(exp_avgs, grads)
  262. masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
  263. mask_scale = [m.mean() for m in masks]
  264. torch._foreach_maximum_(mask_scale, 1e-3)
  265. torch._foreach_div_(masks, mask_scale)
  266. exp_avgs = torch._foreach_mul(exp_avgs, masks)
  267. if no_prox:
  268. torch._foreach_mul_(params, 1 - lr * weight_decay)
  269. torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
  270. torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
  271. else:
  272. torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
  273. torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
  274. torch._foreach_div_(params, 1 + lr * weight_decay)
  275. torch._foreach_zero_(neg_pre_grads)
  276. torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)