radam.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. # mypy: allow-untyped-defs
  2. r"""Implementation for the RAdam algorithm."""
  3. from typing import cast
  4. import torch
  5. from torch import Tensor
  6. from .optimizer import (
  7. _capturable_doc,
  8. _default_to_fused_or_foreach,
  9. _differentiable_doc,
  10. _disable_dynamo_if_unsupported,
  11. _foreach_doc,
  12. _get_capturable_supported_devices,
  13. _get_scalar_dtype,
  14. _get_value,
  15. _maximize_doc,
  16. _params_doc,
  17. _to_scalar,
  18. _use_grad_for_differentiable,
  19. _view_as_real,
  20. Optimizer,
  21. ParamsT,
  22. )
  23. __all__ = ["RAdam", "radam"]
  24. class RAdam(Optimizer): # noqa: D101
  25. def __init__(
  26. self,
  27. params: ParamsT,
  28. lr: float | Tensor = 1e-3,
  29. betas: tuple[float, float] = (0.9, 0.999),
  30. eps: float = 1e-8,
  31. weight_decay: float = 0,
  32. decoupled_weight_decay: bool = False,
  33. *,
  34. foreach: bool | None = None,
  35. maximize: bool = False,
  36. capturable: bool = False,
  37. differentiable: bool = False,
  38. ) -> None: # noqa: D107
  39. if isinstance(lr, Tensor) and lr.numel() != 1:
  40. raise ValueError("Tensor lr must be 1-element")
  41. if not 0.0 <= lr:
  42. raise ValueError(f"Invalid learning rate: {lr}")
  43. if not 0.0 <= eps:
  44. raise ValueError(f"Invalid epsilon value: {eps}")
  45. if not 0.0 <= betas[0] < 1.0:
  46. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  47. if not 0.0 <= betas[1] < 1.0:
  48. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  49. if not 0.0 <= weight_decay:
  50. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  51. defaults = {
  52. "lr": lr,
  53. "betas": betas,
  54. "eps": eps,
  55. "weight_decay": weight_decay,
  56. "maximize": maximize,
  57. "foreach": foreach,
  58. "capturable": capturable,
  59. "decoupled_weight_decay": decoupled_weight_decay,
  60. "differentiable": differentiable,
  61. }
  62. super().__init__(params, defaults)
  63. def __setstate__(self, state): # noqa: D105
  64. super().__setstate__(state)
  65. for group in self.param_groups:
  66. group.setdefault("foreach", None)
  67. group.setdefault("maximize", False)
  68. group.setdefault("differentiable", False)
  69. group.setdefault("decoupled_weight_decay", False)
  70. group.setdefault("capturable", False)
  71. for p in group["params"]:
  72. p_state = self.state.get(p, [])
  73. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  74. step_val = float(p_state["step"])
  75. p_state["step"] = (
  76. torch.tensor(
  77. step_val, dtype=_get_scalar_dtype(), device=p.device
  78. )
  79. if group["capturable"]
  80. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  81. )
  82. def _init_group(
  83. self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
  84. ):
  85. has_complex = False
  86. for p in group["params"]:
  87. if p.grad is not None:
  88. has_complex |= torch.is_complex(p)
  89. params_with_grad.append(p)
  90. if p.grad.is_sparse:
  91. raise RuntimeError("RAdam does not support sparse gradients")
  92. grads.append(p.grad)
  93. state = self.state[p]
  94. # Lazy state initialization
  95. if len(state) == 0:
  96. state["step"] = (
  97. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  98. if group["capturable"]
  99. else torch.tensor(0.0, dtype=_get_scalar_dtype())
  100. )
  101. # Exponential moving average of gradient values
  102. state["exp_avg"] = torch.zeros_like(
  103. p, memory_format=torch.preserve_format
  104. )
  105. # Exponential moving average of squared gradient values
  106. state["exp_avg_sq"] = torch.zeros_like(
  107. p, memory_format=torch.preserve_format
  108. )
  109. exp_avgs.append(state["exp_avg"])
  110. exp_avg_sqs.append(state["exp_avg_sq"])
  111. state_steps.append(state["step"])
  112. return has_complex
  113. @_use_grad_for_differentiable
  114. def step(self, closure=None):
  115. """Perform a single optimization step.
  116. Args:
  117. closure (Callable, optional): A closure that reevaluates the model
  118. and returns the loss.
  119. """
  120. self._accelerator_graph_capture_health_check()
  121. loss = None
  122. if closure is not None:
  123. with torch.enable_grad():
  124. loss = closure()
  125. for group in self.param_groups:
  126. params_with_grad: list[Tensor] = []
  127. grads: list[Tensor] = []
  128. exp_avgs: list[Tensor] = []
  129. exp_avg_sqs: list[Tensor] = []
  130. state_steps: list[Tensor] = []
  131. beta1, beta2 = cast(tuple[float, float], group["betas"])
  132. has_complex = self._init_group(
  133. group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
  134. )
  135. radam(
  136. params_with_grad,
  137. grads,
  138. exp_avgs,
  139. exp_avg_sqs,
  140. state_steps,
  141. beta1=beta1,
  142. beta2=beta2,
  143. lr=group["lr"],
  144. weight_decay=group["weight_decay"],
  145. eps=group["eps"],
  146. maximize=group["maximize"],
  147. foreach=group["foreach"],
  148. capturable=group["capturable"],
  149. differentiable=group["differentiable"],
  150. decoupled_weight_decay=group["decoupled_weight_decay"],
  151. has_complex=has_complex,
  152. )
  153. return loss
  154. RAdam.__doc__ = (
  155. r"""Implements RAdam algorithm.
  156. .. math::
  157. \begin{aligned}
  158. &\rule{110mm}{0.4pt} \\
  159. &\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2
  160. \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
  161. \lambda \text{ (weightdecay)}, \:\textit{maximize} \\
  162. &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay} \\
  163. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  164. v_0 \leftarrow 0 \text{ ( second moment)}, \\
  165. &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex]
  166. &\rule{110mm}{0.4pt} \\
  167. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  168. &\hspace{6mm}\textbf{if} \: \textit{maximize}: \\
  169. &\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  170. &\hspace{6mm}\textbf{else} \\
  171. &\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  172. &\hspace{6mm} \theta_t \leftarrow \theta_{t-1} \\
  173. &\hspace{6mm} \textbf{if} \: \lambda \neq 0 \\
  174. &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
  175. &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t} \\
  176. &\hspace{12mm}\textbf{else} \\
  177. &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t} \\
  178. &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  179. &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  180. &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  181. &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
  182. 2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex]
  183. &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
  184. &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\
  185. &\hspace{12mm} r_t \leftarrow
  186. \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
  187. &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\
  188. &\hspace{6mm}\textbf{else} \\
  189. &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\
  190. &\rule{110mm}{0.4pt} \\[-1.ex]
  191. &\bf{return} \: \theta_t \\[-1.ex]
  192. &\rule{110mm}{0.4pt} \\[-1.ex]
  193. \end{aligned}
  194. For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.
  195. This implementation provides an option to use either the original weight_decay implementation as in Adam
  196. (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied
  197. to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False
  198. (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which
  199. corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information
  200. about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_.
  201. """
  202. + rf"""
  203. Args:
  204. {_params_doc}
  205. lr (float, Tensor, optional): learning rate (default: 1e-3)
  206. betas (Tuple[float, float], optional): coefficients used for computing
  207. running averages of gradient and its square (default: (0.9, 0.999))
  208. eps (float, optional): term added to the denominator to improve
  209. numerical stability (default: 1e-8)
  210. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  211. decoupled_weight_decay (bool, optional): whether to decouple the weight
  212. decay as in AdamW to obtain RAdamW. If True, the algorithm does not
  213. accumulate weight decay in the momentum nor variance. (default: False)
  214. {_foreach_doc}
  215. {_maximize_doc}
  216. {_capturable_doc}
  217. {_differentiable_doc}
  218. .. _On the variance of the adaptive learning rate and beyond:
  219. https://arxiv.org/abs/1908.03265
  220. .. _author's implementation:
  221. https://github.com/LiyuanLucasLiu/RAdam
  222. .. _Decoupled Weight Decay Regularization:
  223. https://arxiv.org/abs/1711.05101
  224. """
  225. )
  226. def _single_tensor_radam(
  227. params: list[Tensor],
  228. grads: list[Tensor],
  229. exp_avgs: list[Tensor],
  230. exp_avg_sqs: list[Tensor],
  231. state_steps: list[Tensor],
  232. *,
  233. beta1: float,
  234. beta2: float,
  235. lr: float,
  236. weight_decay: float,
  237. eps: float,
  238. decoupled_weight_decay: bool,
  239. differentiable: bool,
  240. maximize: bool,
  241. capturable: bool,
  242. has_complex: bool,
  243. ) -> None:
  244. if not torch.jit.is_scripting():
  245. lr = _to_scalar(lr)
  246. for i, param in enumerate(params):
  247. grad = grads[i] if not maximize else -grads[i]
  248. exp_avg = exp_avgs[i]
  249. exp_avg_sq = exp_avg_sqs[i]
  250. step_t = state_steps[i]
  251. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  252. if not torch.compiler.is_compiling() and capturable:
  253. capturable_supported_devices = _get_capturable_supported_devices()
  254. if not (
  255. param.device.type == step_t.device.type
  256. and param.device.type in capturable_supported_devices
  257. ):
  258. raise AssertionError(
  259. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  260. )
  261. if torch.is_complex(param):
  262. param = torch.view_as_real(param)
  263. grad = torch.view_as_real(grad)
  264. exp_avg = torch.view_as_real(exp_avg)
  265. exp_avg_sq = torch.view_as_real(exp_avg_sq)
  266. # update step
  267. step_t += 1
  268. step = step_t if capturable else _get_value(step_t)
  269. if weight_decay != 0:
  270. if decoupled_weight_decay:
  271. param.mul_(1 - lr * weight_decay)
  272. else:
  273. grad = grad.add(param, alpha=weight_decay)
  274. # Decay the first and second moment running average coefficient
  275. exp_avg.lerp_(grad, 1 - beta1)
  276. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  277. bias_correction1 = 1 - beta1**step
  278. bias_correction2 = 1 - beta2**step
  279. # correcting bias for the first moving moment
  280. bias_corrected_exp_avg = exp_avg / bias_correction1
  281. # maximum length of the approximated SMA
  282. rho_inf = 2 / (1 - beta2) - 1
  283. # compute the length of the approximated SMA
  284. rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
  285. def _compute_rect():
  286. # pyrefly: ignore [unsupported-operation]
  287. return (
  288. (rho_t - 4)
  289. * (rho_t - 2)
  290. * rho_inf
  291. / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
  292. ) ** 0.5
  293. def _compute_adaptive_lr():
  294. exp_avg_sq_sqrt = exp_avg_sq.sqrt()
  295. if differentiable:
  296. exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps)
  297. else:
  298. exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
  299. # pyrefly: ignore [unsupported-operation]
  300. return (bias_correction2**0.5) / exp_avg_sq_sqrt
  301. # Compute the variance rectification term and update parameters accordingly
  302. if capturable:
  303. update = torch.where(
  304. rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0
  305. )
  306. param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0)
  307. else:
  308. if rho_t > 5.0:
  309. param.add_(
  310. bias_corrected_exp_avg
  311. * lr
  312. * _compute_adaptive_lr()
  313. * _compute_rect(),
  314. alpha=-1.0,
  315. )
  316. else:
  317. param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)
  318. def _multi_tensor_radam(
  319. params: list[Tensor],
  320. grads: list[Tensor],
  321. exp_avgs: list[Tensor],
  322. exp_avg_sqs: list[Tensor],
  323. state_steps: list[Tensor],
  324. *,
  325. beta1: float,
  326. beta2: float,
  327. lr: float,
  328. weight_decay: float,
  329. eps: float,
  330. decoupled_weight_decay: bool,
  331. differentiable: bool,
  332. maximize: bool,
  333. capturable: bool,
  334. has_complex: bool,
  335. ) -> None:
  336. if len(params) == 0:
  337. return
  338. if differentiable:
  339. raise AssertionError("_foreach ops don't support autograd")
  340. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  341. if not torch.compiler.is_compiling() and capturable:
  342. capturable_supported_devices = _get_capturable_supported_devices(
  343. supports_xla=False
  344. )
  345. if not all(
  346. p.device.type == step.device.type
  347. and p.device.type in capturable_supported_devices
  348. for p, step in zip(params, state_steps, strict=True)
  349. ):
  350. raise AssertionError(
  351. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  352. )
  353. lr = _to_scalar(lr)
  354. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  355. [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
  356. )
  357. for (
  358. grouped_params_,
  359. grouped_grads_,
  360. grouped_exp_avgs_,
  361. grouped_exp_avg_sqs_,
  362. grouped_state_steps_,
  363. ), _ in grouped_tensors.values():
  364. grouped_params = cast(list[Tensor], grouped_params_)
  365. grouped_grads = cast(list[Tensor], grouped_grads_)
  366. grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
  367. grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
  368. grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
  369. # Update steps
  370. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  371. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  372. # wrapped it once now. The alpha is required to assure we go to the right overload.
  373. if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
  374. torch._foreach_add_(
  375. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  376. )
  377. else:
  378. torch._foreach_add_(grouped_state_steps, 1)
  379. if has_complex:
  380. _view_as_real(
  381. grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
  382. )
  383. if maximize:
  384. grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
  385. # maximum length of the approximated SMA
  386. rho_inf = 2 / (1 - beta2) - 1
  387. # compute the length of the approximated SMA
  388. bias_correction1: tuple[Tensor, ...] | list[Tensor]
  389. bias_correction2: tuple[Tensor, ...] | list[Tensor]
  390. rho_t_list: tuple[Tensor, ...] | list[Tensor]
  391. if capturable:
  392. bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
  393. torch._foreach_neg_(bias_correction1)
  394. torch._foreach_add_(bias_correction1, 1)
  395. bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
  396. torch._foreach_mul_(bias_correction2, grouped_state_steps)
  397. torch._foreach_mul_(bias_correction2, 2)
  398. torch._foreach_div_(bias_correction2, bias_correction1)
  399. torch._foreach_neg_(bias_correction2)
  400. torch._foreach_add_(bias_correction2, rho_inf)
  401. rho_t_list = bias_correction2
  402. else:
  403. rho_t_list = [
  404. rho_inf
  405. - 2
  406. * _get_value(step)
  407. * (beta2 ** _get_value(step))
  408. / (1 - beta2 ** _get_value(step))
  409. for step in grouped_state_steps
  410. ]
  411. if weight_decay != 0:
  412. if decoupled_weight_decay:
  413. torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
  414. else:
  415. # Reuse the intermediate memory (grouped_grads) already allocated for maximize
  416. if maximize:
  417. torch._foreach_add_(
  418. grouped_grads, grouped_params, alpha=weight_decay
  419. )
  420. else:
  421. grouped_grads = torch._foreach_add( # type: ignore[assignment]
  422. grouped_grads, grouped_params, alpha=weight_decay
  423. )
  424. # Decay the first and second moment running average coefficient
  425. torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)
  426. torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
  427. torch._foreach_addcmul_(
  428. grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
  429. )
  430. # Delete the local intermediate since it won't be used anymore to save on peak memory
  431. del grouped_grads
  432. if capturable:
  433. num = torch._foreach_sub(rho_t_list, 4)
  434. sub2 = torch._foreach_sub(rho_t_list, 2)
  435. torch._foreach_mul_(num, sub2)
  436. del sub2
  437. torch._foreach_mul_(num, rho_inf)
  438. rho_inf = (rho_inf - 4) * (rho_inf - 2)
  439. denom = torch._foreach_mul(rho_t_list, rho_inf)
  440. torch._foreach_div_(num, denom)
  441. del denom
  442. torch._foreach_sqrt_(num)
  443. # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
  444. rect = [
  445. torch.where(rho_t > 5.0, n, 0.0)
  446. for n, rho_t in zip(num, rho_t_list, strict=True)
  447. ]
  448. del num
  449. del rho_t_list
  450. unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
  451. torch._foreach_mul_(unrect_step_size, lr)
  452. bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
  453. torch._foreach_neg_(bias_correction1)
  454. torch._foreach_add_(bias_correction1, 1)
  455. torch._foreach_div_(unrect_step_size, bias_correction1)
  456. torch._foreach_neg_(unrect_step_size)
  457. bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
  458. torch._foreach_neg_(bias_correction2)
  459. torch._foreach_add_(bias_correction2, 1)
  460. torch._foreach_sqrt_(bias_correction2)
  461. torch._foreach_mul_(bias_correction2, lr)
  462. torch._foreach_mul_(bias_correction2, rect)
  463. del rect
  464. torch._foreach_neg_(bias_correction2)
  465. torch._foreach_div_(bias_correction2, bias_correction1)
  466. del bias_correction1
  467. else:
  468. rect = [
  469. ( # type: ignore[misc]
  470. (rho_t - 4) # type: ignore[arg-type]
  471. * (rho_t - 2)
  472. * rho_inf
  473. / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
  474. )
  475. ** 0.5
  476. if rho_t > 5
  477. else 0
  478. for rho_t in rho_t_list
  479. ]
  480. unrectified = [0 if rect > 0 else 1.0 for rect in rect]
  481. bias_correction1 = [
  482. 1 - beta1 ** _get_value(step) for step in grouped_state_steps
  483. ]
  484. unrect_step_size = [
  485. (lr * rect / bc) * -1
  486. for rect, bc in zip(unrectified, bias_correction1, strict=True)
  487. ]
  488. bias_correction2 = [
  489. ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
  490. for step, rect, bc in zip(
  491. grouped_state_steps, rect, bias_correction1, strict=True
  492. )
  493. ]
  494. buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
  495. torch._foreach_add_(buffer, eps)
  496. torch._foreach_div_(buffer, bias_correction2)
  497. torch._foreach_reciprocal_(buffer)
  498. torch._foreach_add_(buffer, unrect_step_size)
  499. # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size
  500. torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: bool | None = None,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
) -> None:
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.

    Dispatches to ``_multi_tensor_radam`` (foreach) or ``_single_tensor_radam``
    and forwards every argument unchanged.

    Args:
        params: parameters to update; index-aligned with the lists below.
        grads: gradients for ``params``.
        exp_avgs: first-moment buffers for ``params``.
        exp_avg_sqs: second-moment buffers for ``params``.
        state_steps: per-parameter step counters; every entry must be a Tensor.
        decoupled_weight_decay: apply weight decay to the parameters directly
            (AdamW style) instead of adding it to the gradient.
        foreach: select the multi-tensor implementation. ``None`` lets
            ``_default_to_fused_or_foreach`` decide.
        differentiable, capturable, has_complex, maximize: forwarded to the
            selected implementation.
        beta1, beta2, lr, weight_decay, eps: RAdam hyperparameters, forwarded
            to the selected implementation.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-Tensor entry, or if
            ``foreach`` is True while running under TorchScript.
    """
    # Older callers passed plain numbers as step counts; only tensors are accepted now.
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        # Only the foreach choice is used; this file has no fused path (use_fused=False).
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The extra `not torch.jit.is_scripting()` keeps the multi-tensor branch
    # unreachable when scripting (presumably so TorchScript never compiles it).
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        decoupled_weight_decay=decoupled_weight_decay,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )