adopt.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. """ ADOPT PyTorch Optimizer
  2. ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate: https://arxiv.org/abs/2411.02853
  3. Modified for reduced dependencies on PyTorch internals from original at: https://github.com/iShohei220/adopt
  4. @inproceedings{taniguchi2024adopt,
  5. author={Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka},
  6. booktitle = {Advances in Neural Information Processing Systems},
  7. title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
  8. year = {2024}
  9. }
  10. References for added functionality:
  11. Cautious Optimizers: https://arxiv.org/abs/2411.16085
  12. Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
  13. """
  14. from typing import cast, List, Optional, Tuple, Union
  15. import torch
  16. from torch import Tensor
  17. from torch.optim.optimizer import Optimizer
  18. from ._types import ParamsT
  19. __all__ = ["Adopt", "adopt"]
  20. def _view_as_real(params, *state_and_grads):
  21. for i, p in enumerate(params):
  22. if torch.is_complex(p):
  23. params[i] = torch.view_as_real(params[i])
  24. for s in state_and_grads:
  25. s[i] = torch.view_as_real(s[i])
  26. def _get_scalar_dtype(is_fused=None):
  27. if is_fused:
  28. return torch.float32
  29. return (
  30. torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
  31. )
  32. def _is_compiling():
  33. if hasattr(torch, 'compiler') and hasattr(torch.compiler, 'is_compiling'):
  34. return torch.compiler.is_compiling()
  35. else:
  36. return False
  37. def _get_value(x):
  38. # item is significantly faster than a cpu tensor in eager mode
  39. if not torch.jit.is_scripting() and _is_compiling():
  40. return x
  41. else:
  42. return x.item() if isinstance(x, torch.Tensor) else x
  43. class Adopt(Optimizer):
  44. """
  45. ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate: https://arxiv.org/abs/2411.02853
  46. """
  47. def __init__(
  48. self,
  49. params: ParamsT,
  50. lr: Union[float, Tensor] = 1e-3,
  51. betas: Tuple[float, float] = (0.9, 0.9999),
  52. eps: float = 1e-6,
  53. clip_exp: Optional[float] = 0.333,
  54. weight_decay: float = 0.0,
  55. decoupled: bool = False,
  56. corrected_weight_decay: bool = False,
  57. *,
  58. caution: bool = False,
  59. foreach: Optional[bool] = False,
  60. maximize: bool = False,
  61. capturable: bool = False,
  62. differentiable: bool = False,
  63. ):
  64. if isinstance(lr, Tensor):
  65. if foreach and not capturable:
  66. raise ValueError(
  67. "lr as a Tensor is not supported for capturable=False and foreach=True"
  68. )
  69. if lr.numel() != 1:
  70. raise ValueError("Tensor lr must be 1-element")
  71. if not 0.0 <= lr:
  72. raise ValueError(f"Invalid learning rate: {lr}")
  73. if not 0.0 <= eps:
  74. raise ValueError(f"Invalid epsilon value: {eps}")
  75. if not 0.0 <= betas[0] < 1.0:
  76. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  77. if not 0.0 <= betas[1] < 1.0:
  78. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  79. if not 0.0 <= weight_decay:
  80. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  81. defaults = dict(
  82. lr=lr,
  83. betas=betas,
  84. eps=eps,
  85. weight_decay=weight_decay,
  86. clip_exp=clip_exp,
  87. decoupled=decoupled,
  88. corrected_weight_decay=corrected_weight_decay,
  89. caution=caution,
  90. maximize=maximize,
  91. foreach=foreach,
  92. capturable=capturable,
  93. differentiable=differentiable,
  94. )
  95. super().__init__(params, defaults)
  96. def __setstate__(self, state):
  97. super().__setstate__(state)
  98. for group in self.param_groups:
  99. group.setdefault("maximize", False)
  100. group.setdefault("foreach", None)
  101. group.setdefault("capturable", False)
  102. group.setdefault("differentiable", False)
  103. group.setdefault("clip_exp", None)
  104. group.setdefault("caution", False)
  105. group.setdefault("corrected_weight_decay", False)
  106. for p in group["params"]:
  107. p_state = self.state.get(p, [])
  108. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  109. step_val = float(p_state["step"])
  110. p_state["step"] = (
  111. torch.tensor(
  112. step_val,
  113. dtype=_get_scalar_dtype(),
  114. device=p.device,
  115. )
  116. if group["capturable"]
  117. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  118. )
  119. def _init_group(
  120. self,
  121. group,
  122. params_with_grad,
  123. grads,
  124. exp_avgs,
  125. exp_avg_sqs,
  126. state_steps,
  127. ):
  128. has_complex = False
  129. for p in group["params"]:
  130. if p.grad is None:
  131. continue
  132. has_complex |= torch.is_complex(p)
  133. params_with_grad.append(p)
  134. if p.grad.is_sparse:
  135. raise RuntimeError("ADOPT does not support sparse gradients")
  136. grads.append(p.grad)
  137. state = self.state[p]
  138. # Lazy state initialization
  139. if len(state) == 0:
  140. # note(crcrpar): [special device hosting for step]
  141. # Deliberately host `step` on CPU if both capturable and fused are off.
  142. # This is because kernel launches are costly on CUDA and XLA.
  143. state["step"] = (
  144. torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
  145. if group["capturable"]
  146. else torch.tensor(0.0, dtype=_get_scalar_dtype())
  147. )
  148. # Exponential moving average of gradient values
  149. state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
  150. # Exponential moving average of squared gradient values
  151. state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
  152. exp_avgs.append(state["exp_avg"])
  153. exp_avg_sqs.append(state["exp_avg_sq"])
  154. if group["differentiable"] and state["step"].requires_grad:
  155. raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
  156. # Foreach without capturable does not support a tensor lr
  157. if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
  158. raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")
  159. state_steps.append(state["step"])
  160. return has_complex
  161. #@_use_grad_for_differentiable # FIXME internal context mgr, can't use
  162. @torch.no_grad()
  163. def step(self, closure=None):
  164. """Perform a single optimization step.
  165. Args:
  166. closure (Callable, optional): A closure that reevaluates the model
  167. and returns the loss.
  168. """
  169. self._cuda_graph_capture_health_check()
  170. loss = None
  171. if closure is not None:
  172. with torch.enable_grad():
  173. loss = closure()
  174. for group in self.param_groups:
  175. params_with_grad: List[Tensor] = []
  176. grads: List[Tensor] = []
  177. exp_avgs: List[Tensor] = []
  178. exp_avg_sqs: List[Tensor] = []
  179. state_steps: List[Tensor] = []
  180. beta1, beta2 = group["betas"]
  181. has_complex = self._init_group(
  182. group,
  183. params_with_grad,
  184. grads,
  185. exp_avgs,
  186. exp_avg_sqs,
  187. state_steps,
  188. )
  189. adopt(
  190. params_with_grad,
  191. grads,
  192. exp_avgs,
  193. exp_avg_sqs,
  194. state_steps,
  195. has_complex=has_complex,
  196. beta1=beta1,
  197. beta2=beta2,
  198. lr=group["lr"],
  199. weight_decay=group["weight_decay"],
  200. clip_exp=group["clip_exp"],
  201. max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
  202. decoupled=group["decoupled"],
  203. eps=group["eps"],
  204. caution=group["caution"],
  205. maximize=group["maximize"],
  206. foreach=group["foreach"],
  207. capturable=group["capturable"],
  208. differentiable=group["differentiable"],
  209. grad_scale=getattr(self, "grad_scale", None),
  210. found_inf=getattr(self, "found_inf", None),
  211. )
  212. return loss
  213. def _single_tensor_adopt(
  214. params: List[Tensor],
  215. grads: List[Tensor],
  216. exp_avgs: List[Tensor],
  217. exp_avg_sqs: List[Tensor],
  218. state_steps: List[Tensor],
  219. grad_scale: Optional[Tensor],
  220. found_inf: Optional[Tensor],
  221. *,
  222. has_complex: bool,
  223. beta1: float,
  224. beta2: float,
  225. lr: Union[float, Tensor],
  226. weight_decay: float,
  227. clip_exp: Optional[float],
  228. max_lr: Optional[float],
  229. decoupled: bool,
  230. eps: float,
  231. caution: bool,
  232. maximize: bool,
  233. capturable: bool,
  234. differentiable: bool,
  235. ):
  236. assert grad_scale is None and found_inf is None
  237. if torch.jit.is_scripting():
  238. # this assert is due to JIT being dumb and not realizing that the ops below
  239. # have overloads to handle both float and Tensor lrs, so we just assert it's
  240. # a float since most people using JIT are using floats
  241. assert isinstance(lr, float)
  242. for i, param in enumerate(params):
  243. grad = grads[i] if not maximize else -grads[i]
  244. exp_avg = exp_avgs[i]
  245. exp_avg_sq = exp_avg_sqs[i]
  246. step_t = state_steps[i]
  247. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  248. if capturable and not _is_compiling():
  249. from torch.optim.optimizer import _get_capturable_supported_devices
  250. capturable_supported_devices = _get_capturable_supported_devices()
  251. assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
  252. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  253. # update step
  254. step_t += 1
  255. if torch.is_complex(param):
  256. grad = torch.view_as_real(grad)
  257. if exp_avg is not None:
  258. exp_avg = torch.view_as_real(exp_avg)
  259. if exp_avg_sq is not None:
  260. exp_avg_sq = torch.view_as_real(exp_avg_sq)
  261. param = torch.view_as_real(param)
  262. if weight_decay != 0 and not decoupled:
  263. grad = grad.add(param, alpha=weight_decay)
  264. step = step_t if capturable or differentiable else _get_value(step_t)
  265. if step == 1:
  266. exp_avg_sq.addcmul_(grad, grad.conj())
  267. continue
  268. if weight_decay != 0 and decoupled:
  269. wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
  270. param.add_(param, alpha=-wd_scale * weight_decay)
  271. denom = torch.clamp(exp_avg_sq.sqrt(), eps)
  272. normed_grad = grad.div(denom)
  273. if clip_exp is not None:
  274. clip_val = (step - 1) ** clip_exp
  275. normed_grad.clamp_(-clip_val, clip_val)
  276. exp_avg.lerp_(normed_grad, 1 - beta1)
  277. if caution:
  278. # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
  279. mask = (exp_avg * grad > 0).to(grad.dtype)
  280. mask.div_(mask.mean().clamp_(min=1e-3))
  281. exp_avg = exp_avg * mask
  282. param.add_(exp_avg, alpha=-lr)
  283. exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
def _multi_tensor_adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    clip_exp: Optional[float],
    max_lr: Optional[float],
    decoupled: bool,
    eps: float,
    caution: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one ADOPT step using ``torch._foreach_*`` multi-tensor ops.

    Same math as the single-tensor path, but applied per bucket of tensors
    returned by ``Optimizer._group_tensors_by_device_and_dtype``. Params,
    ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps`` are updated in place.

    Raises:
        RuntimeError: if ``lr`` is a Tensor while ``capturable`` is False.
        AssertionError: on unsupported/mismatched devices when ``capturable``,
            if ``grad_scale``/``found_inf`` are given, or if ``differentiable``.
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if capturable and not _is_compiling():
        from torch.optim.optimizer import _get_capturable_supported_devices
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    # Bucket the parallel lists (by device/dtype, per the helper's name) so each
    # foreach call below operates on one bucket at a time.
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        # (_view_as_real replaces entries of these bucket lists in place with real views)
        if has_complex:
            _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

        # Out-of-place negation: the bucket now holds fresh tensors, not the caller's grads.
        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not _is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0 and not decoupled:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            # (in-place is only safe then; otherwise device_grads still aliases the caller's grads)
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

        # NOTE(review): the first tensor's step stands in for the whole bucket; this
        # assumes all params in a bucket share the same step count — confirm.
        # First step: only seed the second-moment estimate, no parameter update.
        if device_state_steps[0] == 1:
            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
            continue

        # Decoupled decay: param <- param * (1 - wd_scale * weight_decay)
        if weight_decay != 0 and decoupled:
            wd_scale = lr ** 2 / max_lr if max_lr is not None else lr
            torch._foreach_add_(device_params, device_params, alpha=-wd_scale * weight_decay)

        # Normalize by the second moment from the previous step (it is updated last).
        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)

        # Clamp to [-(step - 1) ** clip_exp, (step - 1) ** clip_exp], again using
        # the first tensor's step for the bucket.
        if clip_exp is not None:
            clip_val = (device_state_steps[0] - 1) ** clip_exp
            torch._foreach_maximum_(normed_grad, -clip_val)
            torch._foreach_minimum_(normed_grad, clip_val)

        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)

        if caution:
            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
            masks = torch._foreach_mul(device_exp_avgs, device_grads)
            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
            mask_scale = [m.mean() for m in masks]
            torch._foreach_maximum_(mask_scale, 1e-3)
            torch._foreach_div_(masks, mask_scale)
            # Rebinding (out-of-place mul) keeps the stored exp_avg state unmasked.
            device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks)

        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)
  384. #@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) # FIXME internal context mgr, can't use
  385. def adopt(
  386. params: List[Tensor],
  387. grads: List[Tensor],
  388. exp_avgs: List[Tensor],
  389. exp_avg_sqs: List[Tensor],
  390. state_steps: List[Tensor],
  391. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  392. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  393. foreach: Optional[bool] = None,
  394. capturable: bool = False,
  395. differentiable: bool = False,
  396. grad_scale: Optional[Tensor] = None,
  397. found_inf: Optional[Tensor] = None,
  398. has_complex: bool = False,
  399. *,
  400. beta1: float,
  401. beta2: float,
  402. lr: Union[float, Tensor],
  403. weight_decay: float,
  404. clip_exp: Optional[float],
  405. max_lr: Optional[float],
  406. decoupled: bool,
  407. eps: float,
  408. caution: bool,
  409. maximize: bool,
  410. ):
  411. r"""Functional API that performs ADOPT algorithm computation.
  412. """
  413. if foreach is None:
  414. foreach = False
  415. # this check is slow during compilation, so we skip it
  416. # if it's strictly needed we can add this check back in dynamo
  417. if not _is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
  418. raise RuntimeError(
  419. "API has changed, `state_steps` argument must contain a list of singleton tensors"
  420. )
  421. if foreach and torch.jit.is_scripting():
  422. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  423. if foreach and not torch.jit.is_scripting():
  424. func = _multi_tensor_adopt
  425. else:
  426. func = _single_tensor_adopt
  427. func(
  428. params,
  429. grads,
  430. exp_avgs,
  431. exp_avg_sqs,
  432. state_steps,
  433. has_complex=has_complex,
  434. beta1=beta1,
  435. beta2=beta2,
  436. lr=lr,
  437. weight_decay=weight_decay,
  438. clip_exp=clip_exp,
  439. max_lr=max_lr,
  440. decoupled=decoupled,
  441. eps=eps,
  442. caution=caution,
  443. maximize=maximize,
  444. capturable=capturable,
  445. differentiable=differentiable,
  446. grad_scale=grad_scale,
  447. found_inf=found_inf,
  448. )