# nadam.py (26 KB)
  1. # mypy: allow-untyped-defs
  2. r"""Implementation for the NAdam algorithm."""
  3. from typing import cast
  4. import torch
  5. from torch import Tensor
  6. from .optimizer import (
  7. _capturable_doc,
  8. _default_to_fused_or_foreach,
  9. _differentiable_doc,
  10. _disable_dynamo_if_unsupported,
  11. _foreach_doc,
  12. _get_capturable_supported_devices,
  13. _get_scalar_dtype,
  14. _get_value,
  15. _maximize_doc,
  16. _params_doc,
  17. _stack_if_compiling,
  18. _to_scalar,
  19. _use_grad_for_differentiable,
  20. _view_as_real,
  21. Optimizer,
  22. ParamsT,
  23. )
  24. __all__ = ["NAdam", "nadam"]
  25. class NAdam(Optimizer): # noqa: D101
  26. def __init__(
  27. self,
  28. params: ParamsT,
  29. lr: float | Tensor = 2e-3,
  30. betas: tuple[float, float] = (0.9, 0.999),
  31. eps: float = 1e-8,
  32. weight_decay: float = 0,
  33. momentum_decay: float = 4e-3,
  34. decoupled_weight_decay: bool = False,
  35. *,
  36. foreach: bool | None = None,
  37. maximize: bool = False,
  38. capturable: bool = False,
  39. differentiable: bool = False,
  40. ) -> None: # noqa: D107
  41. if isinstance(lr, Tensor) and lr.numel() != 1:
  42. raise ValueError("Tensor lr must be 1-element")
  43. if not 0.0 <= lr:
  44. raise ValueError(f"Invalid learning rate: {lr}")
  45. if not 0.0 <= eps:
  46. raise ValueError(f"Invalid epsilon value: {eps}")
  47. if not 0.0 <= betas[0] < 1.0:
  48. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  49. if not 0.0 <= betas[1] < 1.0:
  50. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  51. if not 0.0 <= weight_decay:
  52. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  53. if not 0.0 <= momentum_decay:
  54. raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
  55. defaults = {
  56. "lr": lr,
  57. "betas": betas,
  58. "eps": eps,
  59. "weight_decay": weight_decay,
  60. "momentum_decay": momentum_decay,
  61. "decoupled_weight_decay": decoupled_weight_decay,
  62. "maximize": maximize,
  63. "foreach": foreach,
  64. "capturable": capturable,
  65. "differentiable": differentiable,
  66. }
  67. super().__init__(params, defaults)
  68. def __setstate__(self, state): # noqa: D105
  69. super().__setstate__(state)
  70. for group in self.param_groups:
  71. group.setdefault("maximize", False)
  72. group.setdefault("foreach", None)
  73. group.setdefault("capturable", False)
  74. group.setdefault("differentiable", False)
  75. group.setdefault("decoupled_weight_decay", False)
  76. for p in group["params"]:
  77. p_state = self.state.get(p, [])
  78. if len(p_state) != 0:
  79. if not torch.is_tensor(p_state["step"]):
  80. step_val = float(p_state["step"])
  81. p_state["step"] = (
  82. torch.tensor(
  83. step_val, dtype=_get_scalar_dtype(), device=p.device
  84. )
  85. if group["capturable"]
  86. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  87. )
  88. if not torch.is_tensor(p_state["mu_product"]):
  89. mu_prod_val = p_state["mu_product"]
  90. p_state["mu_product"] = (
  91. torch.tensor(
  92. mu_prod_val, dtype=_get_scalar_dtype(), device=p.device
  93. )
  94. if group["capturable"]
  95. else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype())
  96. )
  97. def _init_group(
  98. self,
  99. group,
  100. params_with_grad,
  101. grads,
  102. exp_avgs,
  103. exp_avg_sqs,
  104. mu_products,
  105. state_steps,
  106. ):
  107. has_complex = False
  108. for p in group["params"]:
  109. if p.grad is not None:
  110. has_complex |= torch.is_complex(p)
  111. params_with_grad.append(p)
  112. if p.grad.is_sparse:
  113. raise RuntimeError("NAdam does not support sparse gradients")
  114. grads.append(p.grad)
  115. state = self.state[p]
  116. # Lazy state initialization
  117. if len(state) == 0:
  118. # note(crcrpar): [special device hosting for step]
  119. # Deliberately host `step` and `mu_product` on CPU if capturable is False.
  120. # This is because kernel launches are costly on CUDA and XLA.
  121. state["step"] = (
  122. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  123. if group["capturable"]
  124. else torch.tensor(0.0, dtype=_get_scalar_dtype())
  125. )
  126. state["mu_product"] = (
  127. torch.ones((), dtype=_get_scalar_dtype(), device=p.device)
  128. if group["capturable"]
  129. else torch.tensor(1.0, dtype=_get_scalar_dtype())
  130. )
  131. # Exponential moving average of gradient values
  132. state["exp_avg"] = torch.zeros_like(
  133. p, memory_format=torch.preserve_format
  134. )
  135. # Exponential moving average of squared gradient values
  136. state["exp_avg_sq"] = torch.zeros_like(
  137. p, memory_format=torch.preserve_format
  138. )
  139. exp_avgs.append(state["exp_avg"])
  140. exp_avg_sqs.append(state["exp_avg_sq"])
  141. mu_products.append(state["mu_product"])
  142. state_steps.append(state["step"])
  143. return has_complex
  144. @_use_grad_for_differentiable
  145. def step(self, closure=None):
  146. """Perform a single optimization step.
  147. Args:
  148. closure (Callable, optional): A closure that reevaluates the model
  149. and returns the loss.
  150. """
  151. self._accelerator_graph_capture_health_check()
  152. loss = None
  153. if closure is not None:
  154. with torch.enable_grad():
  155. loss = closure()
  156. for group in self.param_groups:
  157. params_with_grad: list[Tensor] = []
  158. grads: list[Tensor] = []
  159. exp_avgs: list[Tensor] = []
  160. exp_avg_sqs: list[Tensor] = []
  161. mu_products: list[Tensor] = []
  162. state_steps: list[Tensor] = []
  163. beta1, beta2 = cast(tuple[float, float], group["betas"])
  164. has_complex = self._init_group(
  165. group,
  166. params_with_grad,
  167. grads,
  168. exp_avgs,
  169. exp_avg_sqs,
  170. mu_products,
  171. state_steps,
  172. )
  173. nadam(
  174. params_with_grad,
  175. grads,
  176. exp_avgs,
  177. exp_avg_sqs,
  178. mu_products,
  179. state_steps,
  180. beta1=beta1,
  181. beta2=beta2,
  182. lr=group["lr"],
  183. weight_decay=group["weight_decay"],
  184. momentum_decay=group["momentum_decay"],
  185. eps=group["eps"],
  186. maximize=group["maximize"],
  187. decoupled_weight_decay=group["decoupled_weight_decay"],
  188. foreach=group["foreach"],
  189. capturable=group["capturable"],
  190. differentiable=group["differentiable"],
  191. has_complex=has_complex,
  192. )
  193. return loss
  194. NAdam.__doc__ = (
  195. r"""Implements NAdam algorithm.
  196. .. math::
  197. \begin{aligned}
  198. &\rule{110mm}{0.4pt} \\
  199. &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
  200. \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
  201. &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
  202. &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\
  203. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  204. v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex]
  205. &\rule{110mm}{0.4pt} \\
  206. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  207. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  208. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  209. &\hspace{5mm}\textbf{else} \\
  210. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  211. &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\
  212. &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
  213. &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
  214. &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
  215. &\hspace{10mm}\textbf{else} \\
  216. &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  217. &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
  218. &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
  219. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  220. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  221. &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
  222. & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
  223. &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  224. &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
  225. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  226. &\rule{110mm}{0.4pt} \\[-1.ex]
  227. &\bf{return} \: \theta_t \\[-1.ex]
  228. &\rule{110mm}{0.4pt} \\[-1.ex]
  229. \end{aligned}
  230. For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
  231. """
  232. + rf"""
  233. Args:
  234. {_params_doc}
  235. lr (float, Tensor, optional): learning rate (default: 2e-3)
  236. betas (Tuple[float, float], optional): coefficients used for computing
  237. running averages of gradient and its square (default: (0.9, 0.999))
  238. eps (float, optional): term added to the denominator to improve
  239. numerical stability (default: 1e-8)
  240. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  241. momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
  242. decoupled_weight_decay (bool, optional): whether to decouple the weight
  243. decay as in AdamW to obtain NAdamW. If True, the algorithm does not
  244. accumulate weight decay in the momentum nor variance. (default: False)
  245. {_foreach_doc}
  246. {_maximize_doc}
  247. {_capturable_doc}
  248. {_differentiable_doc}
  249. .. _Incorporating Nesterov Momentum into Adam:
  250. https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
  251. .. _Decoupled Weight Decay Regularization:
  252. https://arxiv.org/abs/1711.05101
  253. """
  254. )
  255. def _single_tensor_nadam(
  256. params: list[Tensor],
  257. grads: list[Tensor],
  258. exp_avgs: list[Tensor],
  259. exp_avg_sqs: list[Tensor],
  260. mu_products: list[Tensor],
  261. state_steps: list[Tensor],
  262. *,
  263. beta1: float,
  264. beta2: float,
  265. lr: float,
  266. weight_decay: float,
  267. momentum_decay: float,
  268. eps: float,
  269. decoupled_weight_decay: bool,
  270. maximize: bool,
  271. capturable: bool,
  272. differentiable: bool,
  273. has_complex: bool,
  274. ) -> None:
  275. if not torch.jit.is_scripting():
  276. lr = _to_scalar(lr)
  277. for i, param in enumerate(params):
  278. grad = grads[i] if not maximize else -grads[i]
  279. exp_avg = exp_avgs[i]
  280. exp_avg_sq = exp_avg_sqs[i]
  281. mu_product = mu_products[i]
  282. step_t = state_steps[i]
  283. if torch.is_complex(param):
  284. param = torch.view_as_real(param)
  285. grad = torch.view_as_real(grad)
  286. exp_avg = torch.view_as_real(exp_avg)
  287. exp_avg_sq = torch.view_as_real(exp_avg_sq)
  288. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  289. if not torch.compiler.is_compiling() and capturable:
  290. capturable_supported_devices = _get_capturable_supported_devices()
  291. if not (
  292. param.device.type == mu_product.device.type == step_t.device.type
  293. and param.device.type in capturable_supported_devices
  294. ):
  295. raise AssertionError(
  296. f"If capturable=True, params, mu_products and state_steps must be "
  297. f"on supported devices: {capturable_supported_devices}."
  298. )
  299. # update step
  300. step_t += 1
  301. if capturable:
  302. step = step_t
  303. else:
  304. step = _get_value(step_t)
  305. bias_correction2 = 1 - beta2**step
  306. if weight_decay != 0:
  307. if decoupled_weight_decay:
  308. # Perform stepweight decay
  309. param.mul_(1 - lr * weight_decay)
  310. else:
  311. grad = grad.add(param, alpha=weight_decay)
  312. # calculate the momentum cache \mu^{t} and \mu^{t+1}
  313. mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay)))
  314. mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
  315. # update mu_product
  316. mu_product *= mu
  317. # decay the first and second moment running average coefficient
  318. exp_avg.lerp_(grad, 1 - beta1)
  319. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  320. denom = exp_avg_sq.div(bias_correction2).sqrt()
  321. if differentiable or capturable:
  322. denom = denom.add(eps)
  323. # Make autograd track the operations
  324. # by updating the grad and exp_avg directly and not using the
  325. # scalar "value" argument of addcdiv.
  326. mu_product_next = mu_product * mu_next
  327. grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product))
  328. exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next))
  329. param.addcdiv_(grad, denom)
  330. param.addcdiv_(exp_avg, denom)
  331. else:
  332. mu_product_next = _get_value(mu_product) * mu_next
  333. denom.add_(eps)
  334. param.addcdiv_(
  335. grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
  336. )
  337. param.addcdiv_(
  338. exp_avg,
  339. denom,
  340. value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)),
  341. )
  342. def _multi_tensor_nadam(
  343. params: list[Tensor],
  344. grads: list[Tensor],
  345. exp_avgs: list[Tensor],
  346. exp_avg_sqs: list[Tensor],
  347. mu_products: list[Tensor],
  348. state_steps: list[Tensor],
  349. *,
  350. beta1: float,
  351. beta2: float,
  352. lr: float,
  353. weight_decay: float,
  354. momentum_decay: float,
  355. eps: float,
  356. decoupled_weight_decay: bool,
  357. maximize: bool,
  358. capturable: bool,
  359. differentiable: bool,
  360. has_complex: bool,
  361. ) -> None:
  362. if len(params) == 0:
  363. return
  364. if differentiable:
  365. raise AssertionError("_foreach ops don't support autograd")
  366. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  367. if not torch.compiler.is_compiling() and capturable:
  368. capturable_supported_devices = _get_capturable_supported_devices(
  369. supports_xla=False
  370. )
  371. if not all(
  372. p.device.type == mp.device.type == step.device.type
  373. and p.device.type in capturable_supported_devices
  374. for p, mp, step in zip(params, mu_products, state_steps, strict=True)
  375. ):
  376. raise AssertionError(
  377. "If capturable=True, "
  378. "params, mu_products, and state_steps must be on supported devices: "
  379. f"{capturable_supported_devices}."
  380. )
  381. lr = _to_scalar(lr)
  382. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  383. [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps] # type: ignore[list-item]
  384. )
  385. for (
  386. grouped_params_,
  387. grouped_grads_,
  388. grouped_exp_avgs_,
  389. grouped_exp_avg_sqs_,
  390. grouped_mu_products_,
  391. grouped_state_steps_,
  392. ), _ in grouped_tensors.values():
  393. grouped_params = cast(list[Tensor], grouped_params_)
  394. grouped_grads = cast(list[Tensor], grouped_grads_)
  395. grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
  396. grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
  397. grouped_mu_products = cast(list[Tensor], grouped_mu_products_)
  398. grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
  399. # handle complex
  400. if has_complex:
  401. _view_as_real(
  402. grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
  403. )
  404. if maximize:
  405. grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
  406. # Update steps
  407. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  408. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  409. # wrapped it once now. The alpha is required to assure we go to the right overload.
  410. if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
  411. torch._foreach_add_(
  412. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  413. )
  414. else:
  415. torch._foreach_add_(grouped_state_steps, 1)
  416. if weight_decay != 0:
  417. if decoupled_weight_decay:
  418. # Perform stepweight decay
  419. torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
  420. else:
  421. # Reuse the intermediate memory (grouped_grads) already allocated for maximize
  422. if maximize:
  423. torch._foreach_add_(
  424. grouped_grads, grouped_params, alpha=weight_decay
  425. )
  426. else:
  427. grouped_grads = torch._foreach_add( # type: ignore[assignment]
  428. grouped_grads, grouped_params, alpha=weight_decay
  429. )
  430. # Decay the first and second moment running average coefficient
  431. torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)
  432. torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
  433. torch._foreach_addcmul_(
  434. grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
  435. )
  436. exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
  437. bias_correction_sqrt: tuple[Tensor, ...] | list[Tensor]
  438. mus: tuple[Tensor, ...] | list[Tensor]
  439. mu_nexts: tuple[Tensor, ...] | list[Tensor]
  440. if capturable:
  441. # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
  442. exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
  443. mus = torch._foreach_pow(0.96, exponent)
  444. torch._foreach_mul_(mus, -0.5)
  445. torch._foreach_add_(mus, 1.0)
  446. torch._foreach_mul_(mus, beta1)
  447. # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
  448. torch._foreach_add_(exponent, momentum_decay)
  449. mu_nexts = torch._foreach_pow(0.96, exponent)
  450. torch._foreach_mul_(mu_nexts, -0.5)
  451. torch._foreach_add_(mu_nexts, 1.0)
  452. torch._foreach_mul_(mu_nexts, beta1)
  453. # save peak memory as we don't need exponent anymore
  454. del exponent
  455. bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps)
  456. # foreach_sub doesn't allow a scalar as the first arg
  457. torch._foreach_sub_(bias_correction_sqrt, 1.0)
  458. torch._foreach_neg_(bias_correction_sqrt)
  459. torch._foreach_sqrt_(bias_correction_sqrt)
  460. else:
  461. bias_correction_sqrt = [
  462. (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps
  463. ]
  464. mus = [
  465. beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay)))
  466. for step in grouped_state_steps
  467. ]
  468. mu_nexts = [
  469. beta1
  470. * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
  471. for step in grouped_state_steps
  472. ]
  473. # update mu_products
  474. torch._foreach_mul_(grouped_mu_products, mus)
  475. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
  476. torch._foreach_add_(exp_avg_sq_sqrt, eps)
  477. # explicitly delete bias_correction refs to save memory
  478. del bias_correction_sqrt
  479. if capturable:
  480. # Build up the step_size multiplier for grad, reusing mus' memory
  481. torch._foreach_sub_(mus, 1.0)
  482. torch._foreach_mul_(mus, lr)
  483. # foreach_sub doesn't allow a scalar as the first arg
  484. denom = torch._foreach_sub(grouped_mu_products, 1.0)
  485. torch._foreach_neg_(denom)
  486. torch._foreach_div_(mus, denom)
  487. # - lr * (1 - mu) / (1 - mu_product)
  488. step_size_grads = mus
  489. # explicitly delete denom to save memory
  490. del denom
  491. # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory
  492. denom = torch._foreach_mul(grouped_mu_products, mu_nexts)
  493. torch._foreach_mul_(mu_nexts, lr)
  494. # foreach_sub doesn't allow a scalar as the first arg, but it's okay because
  495. # we need a negative here anyway
  496. torch._foreach_sub_(denom, 1.0)
  497. torch._foreach_div_(mu_nexts, denom)
  498. # - lr * mu_next / (1 - mu_product * mu_next)
  499. step_size_expavg = mu_nexts
  500. # explicitly delete denom to save memory
  501. del denom
  502. # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors
  503. # and mul'ing with grouped_grads will result in a list of bigger Tensors
  504. numerator = torch._foreach_mul(step_size_grads, grouped_grads)
  505. torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs)
  506. # finally, update params
  507. torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt)
  508. else:
  509. step_size_grads = _stack_if_compiling(
  510. [
  511. (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
  512. for mu_product, mu in zip(grouped_mu_products, mus, strict=True)
  513. ]
  514. )
  515. step_size_expavg = _stack_if_compiling(
  516. [
  517. (
  518. _get_value(lr)
  519. * mu_next
  520. / (1.0 - _get_value(mu_product) * mu_next)
  521. )
  522. * -1
  523. for mu_product, mu_next in zip(
  524. grouped_mu_products, mu_nexts, strict=True
  525. )
  526. ]
  527. )
  528. torch._foreach_addcdiv_(
  529. grouped_params,
  530. grouped_grads,
  531. exp_avg_sq_sqrt,
  532. step_size_grads, # type: ignore[arg-type]
  533. )
  534. torch._foreach_addcdiv_(
  535. grouped_params,
  536. grouped_exp_avgs,
  537. exp_avg_sq_sqrt,
  538. step_size_expavg, # type: ignore[arg-type]
  539. )
  540. @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
  541. def nadam(
  542. params: list[Tensor],
  543. grads: list[Tensor],
  544. exp_avgs: list[Tensor],
  545. exp_avg_sqs: list[Tensor],
  546. mu_products: list[Tensor],
  547. state_steps: list[Tensor],
  548. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  549. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  550. decoupled_weight_decay: bool = False,
  551. foreach: bool | None = None,
  552. capturable: bool = False,
  553. differentiable: bool = False,
  554. has_complex: bool = False,
  555. maximize: bool = False,
  556. *,
  557. beta1: float,
  558. beta2: float,
  559. lr: float,
  560. weight_decay: float,
  561. momentum_decay: float,
  562. eps: float,
  563. ) -> None:
  564. r"""Functional API that performs NAdam algorithm computation.
  565. See :class:`~torch.optim.NAdam` for details.
  566. """
  567. if not all(isinstance(t, torch.Tensor) for t in state_steps):
  568. raise RuntimeError(
  569. "API has changed, `state_steps` argument must contain a list of singleton tensors"
  570. )
  571. if not all(isinstance(t, torch.Tensor) for t in mu_products):
  572. raise RuntimeError(
  573. "API has changed, `mu_products` argument must contain a list of singleton tensors"
  574. )
  575. if foreach is None:
  576. _, foreach = _default_to_fused_or_foreach(
  577. params, differentiable, use_fused=False
  578. )
  579. if foreach and torch.jit.is_scripting():
  580. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  581. if foreach and not torch.jit.is_scripting():
  582. func = _multi_tensor_nadam
  583. else:
  584. func = _single_tensor_nadam
  585. func(
  586. params,
  587. grads,
  588. exp_avgs,
  589. exp_avg_sqs,
  590. mu_products,
  591. state_steps,
  592. beta1=beta1,
  593. beta2=beta2,
  594. lr=lr,
  595. weight_decay=weight_decay,
  596. momentum_decay=momentum_decay,
  597. maximize=maximize,
  598. decoupled_weight_decay=decoupled_weight_decay,
  599. eps=eps,
  600. capturable=capturable,
  601. differentiable=differentiable,
  602. has_complex=has_complex,
  603. )