_adafactor.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. from typing import cast, TYPE_CHECKING
  4. import torch
  5. from torch import Tensor
  6. from .optimizer import (
  7. _disable_dynamo_if_unsupported,
  8. _get_scalar_dtype,
  9. _maximize_doc,
  10. _params_doc,
  11. _to_scalar,
  12. Optimizer,
  13. ParamsT,
  14. TensorListList,
  15. )
  16. __all__ = ["Adafactor", "adafactor"]
  17. class Adafactor(Optimizer):
  18. def __init__(
  19. self,
  20. params: ParamsT,
  21. lr: float | Tensor = 1e-2,
  22. beta2_decay: float = -0.8,
  23. eps: tuple[float | None, float] = (None, 1e-3),
  24. d: float = 1.0,
  25. weight_decay: float = 0.0,
  26. *,
  27. foreach: bool | None = None,
  28. maximize: bool = False,
  29. ) -> None:
  30. if isinstance(lr, Tensor) and lr.numel() != 1:
  31. raise ValueError("Tensor lr must be 1-element")
  32. if not 0.0 <= lr:
  33. raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
  34. if not 0.0 >= beta2_decay:
  35. raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}")
  36. if eps[0] is not None and not 0.0 <= eps[0]:
  37. raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}")
  38. if not 0.0 <= eps[1]:
  39. raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}")
  40. if not 1.0 <= d:
  41. raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
  42. if not 0.0 <= weight_decay:
  43. raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
  44. defaults = {
  45. "lr": lr,
  46. "beta2_decay": beta2_decay,
  47. "eps": eps,
  48. "d": d,
  49. "weight_decay": weight_decay,
  50. "foreach": foreach,
  51. "maximize": maximize,
  52. }
  53. super().__init__(params, defaults)
  54. def __setstate__(self, state):
  55. super().__setstate__(state)
  56. for group in self.param_groups:
  57. group.setdefault("foreach", None)
  58. for p in group["params"]:
  59. p_state = self.state.get(p, [])
  60. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  61. step_val = float(p_state["step"])
  62. p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())
  63. def _init_group(
  64. self,
  65. group,
  66. params_with_grad,
  67. grads,
  68. row_vars,
  69. col_vars,
  70. variances,
  71. state_steps,
  72. ) -> bool:
  73. for p in group["params"]:
  74. if p.grad is None:
  75. continue
  76. if torch.is_complex(p):
  77. raise RuntimeError("Adafactor does not support complex parameters")
  78. if p.grad.is_sparse:
  79. raise RuntimeError("Adafactor does not support sparse gradients")
  80. params_with_grad.append(p)
  81. grads.append(p.grad)
  82. state = self.state[p]
  83. # State initialization
  84. if len(state) == 0:
  85. # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
  86. # This is because kernel launches are costly on CUDA and XLA.
  87. state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
  88. if p.grad.dim() > 1:
  89. row_shape = list(p.grad.shape)
  90. row_shape[-1] = 1
  91. # Row factor of variance, NOT the same shape as grads (will be reduced along last dim)
  92. state["row_var"] = p.grad.new_zeros(row_shape)
  93. col_shape = list(p.grad.shape)
  94. col_shape[-2] = 1
  95. # Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim)
  96. state["col_var"] = p.grad.new_zeros(col_shape)
  97. else:
  98. state["variance"] = torch.zeros_like(
  99. p.grad, memory_format=torch.preserve_format
  100. )
  101. row_vars.append(state.get("row_var", None))
  102. col_vars.append(state.get("col_var", None))
  103. variances.append(state.get("variance", None))
  104. state_steps.append(state["step"])
  105. return False # has_complex
  106. @torch.no_grad()
  107. def step(self, closure=None):
  108. r"""Perform a single optimization step.
  109. Args:
  110. closure (Callable, optional): A closure that reevaluates the model
  111. and returns the loss.
  112. """
  113. self._accelerator_graph_capture_health_check()
  114. loss = None
  115. if closure is not None:
  116. with torch.enable_grad():
  117. loss = closure()
  118. for group in self.param_groups:
  119. params_with_grad: list[Tensor] = []
  120. grads: list[Tensor] = []
  121. row_vars: list[Tensor | None] = []
  122. col_vars: list[Tensor | None] = []
  123. variances: list[Tensor | None] = []
  124. state_steps: list[Tensor] = []
  125. eps1, eps2 = group["eps"]
  126. has_complex = self._init_group(
  127. group,
  128. params_with_grad,
  129. grads,
  130. row_vars,
  131. col_vars,
  132. variances,
  133. state_steps,
  134. )
  135. adafactor(
  136. params_with_grad,
  137. grads,
  138. row_vars,
  139. col_vars,
  140. variances,
  141. state_steps,
  142. d=group["d"],
  143. lr=group["lr"],
  144. beta2_decay=group["beta2_decay"],
  145. weight_decay=group["weight_decay"],
  146. eps1=eps1,
  147. eps2=eps2,
  148. foreach=group["foreach"],
  149. maximize=group["maximize"],
  150. grad_scale=getattr(self, "grad_scale", None),
  151. found_inf=getattr(self, "found_inf", None),
  152. has_complex=has_complex,
  153. )
  154. return loss
  155. Adafactor.__doc__ = (
  156. r"""Implements Adafactor algorithm.
  157. .. math::
  158. \begin{aligned}
  159. &\rule{110mm}{0.4pt} \\
  160. &\textbf{input} : \gamma \text{(lr)}, \: \tau
  161. \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\
  162. &\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
  163. &\hspace{15mm} \: \lambda \text{(weight decay)},
  164. \: \textit{maximize} \\
  165. &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
  166. &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
  167. &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex]
  168. &\rule{110mm}{0.4pt} \\
  169. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  170. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  171. &\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  172. &\hspace{5mm}\textbf{else} \\
  173. &\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  174. &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
  175. &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\
  176. &\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2,
  177. \text{RMS}(\theta_{t-1}))\rho_t \\
  178. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
  179. &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
  180. &\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
  181. (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
  182. &\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
  183. (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
  184. &\hspace{10mm}\widehat{V}_t \leftarrow
  185. \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
  186. &\hspace{5mm}\textbf{else} \\
  187. &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
  188. (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
  189. &\hspace{5mm}U_t \leftarrow
  190. \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
  191. &\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
  192. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\
  193. &\rule{110mm}{0.4pt} \\[-1.ex]
  194. &\bf{return} \: \theta_t \\[-1.ex]
  195. &\rule{110mm}{0.4pt} \\[-1.ex]
  196. \end{aligned}
  197. For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
  198. """
  199. + rf"""
  200. Args:
  201. {_params_doc}
  202. lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
  203. learning rate, and Noam Shazeer and Mitchell Stern do not use lr at all.
  204. Deviating from the paper, this implementation uses lr for applying weight
  205. decay and as the maximum value for relative step size rho_t. Note that in
  206. the paper, a constant of 0.01 is used as the maximum value for relative
  207. step size, and so we set 0.01 as the default value. (default: 1e-2)
  208. beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers
  209. to the coefficient used for computing the running average of the gradient
  210. squared. (default: -0.8)
  211. eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
  212. of the update calculation to improve numerical stability. This use of epsilon1
  213. deviates from the algorithm written in the paper! See note below for more details.
  214. epsilon2 is the term used to avoid having too small a weight update when applying
  215. parameter scaling. (default: (None, 1e-3))
  216. d (float, optional): the clipping threshold, used to avoid larger-than-desired
  217. updates.
  218. weight_decay (float, optional): weight decay coefficient (default: 1e-2)
  219. foreach (bool, optional): whether foreach implementation of optimizer is used. Note
  220. that the foreach implementation uses ~ sizeof(params) more peak memory than the
  221. for-loop version due to the intermediates being a tensorlist vs just one tensor.
  222. As Adafactor is commonly used when memory is prohibitive, Adafactor will default
  223. to the slower single tensor for-loop implementation unless this flag is explicitly
  224. True. This behavior is contrary to other optimizers, which will attempt defaulting
  225. to foreach on CUDA for faster runtime. (default: None)
  226. {_maximize_doc}"""
  227. + r"""
  228. .. Note::
  229. The implementation of Adafactor subtly differs from Noam Shazeer and Mitchell Stern
  230. and implementations in some other frameworks with its use of learning rate and
  231. :math:`\epsilon_1`.
  232. Regarding the learning rate hyperparameter: Noam Shazeer and Mitchell Stern do not
  233. use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to
  234. affect the step size.
  235. This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:
  236. .. math::
  237. \begin{aligned}
  238. &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
  239. \end{aligned}
  240. This differs from Noam Shazeer and Mitchell Stern, who use a constant of 0.01 as
  241. the maximum value of :math:`\rho_t`
  242. .. math::
  243. \begin{aligned}
  244. &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
  245. \end{aligned}
  246. Noam Shazeer and Mitchell Stern do not enforce an opinion on how weight decay should
  247. be computed, and so we use the learning rate as a coefficient for decoupled weight
  248. decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_.
  249. Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the
  250. presumed intention of Noam Shazeer and Mitchell Stern to use :math:`\epsilon_1` as
  251. a stabilizing term when the squared gradient becomes small.
  252. This stabilization can be written as
  253. .. math::
  254. \begin{aligned}
  255. &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
  256. (1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m \\
  257. &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
  258. (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m) \\
  259. &\hspace{5mm}\widehat{V}_t \leftarrow
  260. \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
  261. &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
  262. \end{aligned}
  263. where the row and column factors of gradient squared :math:`R_t` and :math:`C_t`
  264. are left alone, and we apply :math:`\epsilon_1` at the final calculation of
  265. the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`.
  266. This is in contrast to Noam Shazeer and Mitchell Stern and other frameworks which
  267. apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but
  268. not in the calculations after:
  269. .. math::
  270. \begin{aligned}
  271. &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
  272. (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
  273. &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
  274. (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
  275. &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
  276. &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\
  277. \end{aligned}
  278. You may note that Noam Shazeer and Mitchell Stern describe using the sum of squared gradients,
  279. while this implementation uses the mean instead. This choice is mathematically equivalent and
  280. allows for greater numerical stability for large sums.
  281. .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
  282. https://arxiv.org/pdf/1804.04235
  283. .. _Decoupled Weight Decay Regularization:
  284. https://arxiv.org/abs/1711.05101
  285. """
  286. )
  287. def _single_tensor_adafactor(
  288. params: list[Tensor],
  289. grads: list[Tensor],
  290. # If grad is 1-dimensional (aka a vector), there is no factorization necessary
  291. # so row_var and col_var will be None while variance will be filled.
  292. # Contrarily, for a grad with multiple dimensions, we will factor along the last
  293. # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
  294. row_vars: list[Tensor | None],
  295. col_vars: list[Tensor | None],
  296. variances: list[Tensor | None],
  297. state_steps: list[Tensor],
  298. grad_scale: Tensor | None,
  299. found_inf: Tensor | None,
  300. *,
  301. d: float,
  302. lr: Tensor | float,
  303. beta2_decay: float,
  304. weight_decay: float,
  305. eps1: float | None,
  306. eps2: float,
  307. maximize: bool,
  308. has_complex: bool,
  309. ) -> None:
  310. if grad_scale is not None or found_inf is not None:
  311. raise AssertionError("Grad scaling should occur outside of optimizer.step()")
  312. if torch.jit.is_scripting():
  313. # this assert is due to JIT being dumb and not realizing that the ops below
  314. # have overloads to handle both float and Tensor lrs, so we just assert it's
  315. # a float since most people using JIT are using floats
  316. if not isinstance(lr, float):
  317. raise AssertionError(f"Expected lr to be a float, but got {type(lr)}")
  318. else:
  319. lr = _to_scalar(lr)
  320. for i, param in enumerate(params):
  321. grad = grads[i] if not maximize else -grads[i]
  322. step_t = state_steps[i]
  323. row_var = row_vars[i]
  324. col_var = col_vars[i]
  325. variance = variances[i]
  326. if eps1 is None:
  327. eps1 = torch.finfo(param.dtype).eps
  328. # update step
  329. step_t += 1
  330. step_float = step_t.item()
  331. one_minus_beta2_t = step_float**beta2_decay
  332. rho_t = min(lr, 1 / (step_float**0.5))
  333. alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t
  334. # Perform stepweight decay
  335. if weight_decay != 0:
  336. param.mul_(1 - lr * weight_decay)
  337. if grad.dim() > 1:
  338. if row_var is None or col_var is None:
  339. raise AssertionError(
  340. "row_var and col_var should be defined when grad is multidimensional"
  341. )
  342. # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
  343. row_mean = (
  344. torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
  345. )
  346. row_var.lerp_(row_mean, one_minus_beta2_t)
  347. # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
  348. col_mean = (
  349. torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
  350. )
  351. col_var.lerp_(col_mean, one_minus_beta2_t)
  352. var_estimate = row_var @ col_var
  353. var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
  354. else:
  355. if variance is None:
  356. raise AssertionError("variance should be defined when grad is a vector")
  357. grad_squared = grad * grad
  358. variance.lerp_(grad_squared, one_minus_beta2_t)
  359. # avoid writing into variance during update
  360. var_estimate = variance.clone()
  361. # square the eps1 as we sqrt after to keep eps1's magnitude
  362. update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_()
  363. update.mul_(grad)
  364. denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))
  365. param.add_(update, alpha=-alpha / denom)
  366. def _group_tensors_by_device_dtype_and_is_multidim(
  367. tensorlists: TensorListList,
  368. ) -> dict[
  369. tuple[torch.device | None, torch.dtype | None, bool],
  370. list[list[Tensor | None]],
  371. ]:
  372. """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor
  373. has multiple dims or just one dim (is a vector). This allows the foreach impl of
  374. Adafactor to assume that every group of params will either be factored or not."""
  375. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists)
  376. ultra_grouped_tensors: dict[
  377. tuple[torch.device | None, torch.dtype | None, bool],
  378. list[list[Tensor | None]],
  379. ] = {}
  380. for (device, dtype), (tensorlists, _) in grouped_tensors.items():
  381. matrix_key = (device, dtype, True)
  382. vector_key = (device, dtype, False)
  383. # assumes grad is the second tensorlist
  384. for j, tensor in enumerate(tensorlists[1]):
  385. if tensor is None:
  386. raise AssertionError("grad should not be None")
  387. if tensor.dim() > 1:
  388. if matrix_key not in ultra_grouped_tensors:
  389. ultra_grouped_tensors[matrix_key] = [[] for _ in tensorlists]
  390. for i in range(len(tensorlists)):
  391. ultra_grouped_tensors[matrix_key][i].append(tensorlists[i][j])
  392. else:
  393. if vector_key not in ultra_grouped_tensors:
  394. ultra_grouped_tensors[vector_key] = [[] for _ in tensorlists]
  395. for i in range(len(tensorlists)):
  396. ultra_grouped_tensors[vector_key][i].append(tensorlists[i][j])
  397. return ultra_grouped_tensors
  398. def _multi_tensor_adafactor(
  399. params: list[Tensor],
  400. grads: list[Tensor],
  401. # If grad is 1-dimensional (aka a vector), there is no factorization necessary
  402. # so row_var and col_var will be None while variance will be filled.
  403. # Contrarily, for a grad with multiple dimensions, we will factor along the last
  404. # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
  405. row_vars: list[Tensor | None],
  406. col_vars: list[Tensor | None],
  407. variances: list[Tensor | None],
  408. state_steps: list[Tensor],
  409. grad_scale: Tensor | None,
  410. found_inf: Tensor | None,
  411. *,
  412. d: float,
  413. lr: Tensor | float,
  414. beta2_decay: float,
  415. weight_decay: float,
  416. eps1: float | None,
  417. eps2: float,
  418. maximize: bool,
  419. has_complex: bool,
  420. ) -> None:
  421. if len(params) == 0:
  422. return
  423. if grad_scale is not None or found_inf is not None:
  424. raise AssertionError("Grad scaling should occur outside of optimizer.step()")
  425. lr = _to_scalar(lr)
  426. grouped_tensors = _group_tensors_by_device_dtype_and_is_multidim(
  427. [params, grads, row_vars, col_vars, variances, state_steps] # type: ignore[list-item]
  428. )
  429. for (_, dtype, is_multidim), (
  430. (
  431. device_params_,
  432. device_grads_,
  433. device_row_vars_,
  434. device_col_vars_,
  435. device_variances_,
  436. device_state_steps_,
  437. )
  438. ) in grouped_tensors.items():
  439. device_params = cast(list[Tensor], device_params_)
  440. device_grads = cast(list[Tensor], device_grads_)
  441. device_state_steps = cast(list[Tensor], device_state_steps_)
  442. if eps1 is None:
  443. if dtype is None:
  444. raise AssertionError(
  445. "dtype is needed to compute eps1 when eps1 is unset"
  446. )
  447. eps1 = torch.finfo(dtype).eps
  448. if TYPE_CHECKING:
  449. assert device_state_steps[0] is not None
  450. if maximize:
  451. device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
  452. # Update steps
  453. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  454. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  455. # wrapped it once now. The alpha is required to assure we go to the right overload.
  456. if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
  457. torch._foreach_add_(
  458. device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  459. )
  460. else:
  461. torch._foreach_add_(device_state_steps, 1.0)
  462. one_minus_beta2_ts = []
  463. beta2_ts = []
  464. rho_ts = []
  465. for s in device_state_steps:
  466. one_minus_beta2_ts.append(s.item() ** beta2_decay)
  467. beta2_ts.append(1 - s.item() ** beta2_decay)
  468. rho_ts.append(min(lr, 1 / (s.item() ** 0.5)))
  469. alphas = [
  470. max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r
  471. for p, r in zip(device_params, rho_ts, strict=True)
  472. ]
  473. # Perform stepweight decay
  474. if weight_decay != 0:
  475. torch._foreach_mul_(device_params, 1 - lr * weight_decay)
  476. if is_multidim:
  477. device_row_vars = cast(list[Tensor], device_row_vars_)
  478. device_col_vars = cast(list[Tensor], device_col_vars_)
  479. if device_row_vars[0] is None or device_col_vars[0] is None:
  480. raise AssertionError(
  481. "row_var and col_var should be defined when grad is multidimensional"
  482. )
  483. # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
  484. row_means = [
  485. torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
  486. ]
  487. torch._foreach_mul_(row_means, row_means)
  488. torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
  489. torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts)
  490. del row_means
  491. # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
  492. col_means = [
  493. torch.norm(grad, dim=-2, keepdim=True) for grad in device_grads
  494. ]
  495. torch._foreach_mul_(col_means, col_means)
  496. torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
  497. torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts)
  498. del col_means
  499. var_estimates = [
  500. row_var @ col_var
  501. for row_var, col_var in zip(
  502. device_row_vars, device_col_vars, strict=True
  503. )
  504. ]
  505. row_var_means = [
  506. row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars
  507. ]
  508. torch._foreach_clamp_min_(row_var_means, eps1)
  509. torch._foreach_div_(var_estimates, row_var_means)
  510. del row_var_means
  511. else:
  512. device_variances = cast(list[Tensor], device_variances_)
  513. if device_variances[0] is None:
  514. raise AssertionError("variance should be defined when grad is a vector")
  515. grads_squared = torch._foreach_mul(device_grads, device_grads)
  516. torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
  517. del grads_squared
  518. # avoid writing into variance during update
  519. var_estimates = [v.clone() for v in device_variances]
  520. # square the eps1 as we sqrt after to keep eps1's magnitude
  521. torch._foreach_clamp_min_(var_estimates, eps1 * eps1)
  522. torch._foreach_rsqrt_(var_estimates)
  523. torch._foreach_mul_(var_estimates, device_grads)
  524. updates = var_estimates
  525. alphas = [
  526. -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)))
  527. for a, update in zip(alphas, updates, strict=True)
  528. ]
  529. torch._foreach_mul_(updates, alphas)
  530. torch._foreach_add_(device_params, updates)
  531. @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
  532. def adafactor(
  533. params: list[Tensor],
  534. grads: list[Tensor],
  535. row_vars: list[Tensor | None],
  536. col_vars: list[Tensor | None],
  537. variances: list[Tensor | None],
  538. state_steps: list[Tensor],
  539. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  540. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  541. foreach: bool | None = None,
  542. grad_scale: Tensor | None = None,
  543. found_inf: Tensor | None = None,
  544. has_complex: bool = False,
  545. *,
  546. d: float,
  547. lr: float | Tensor,
  548. beta2_decay: float,
  549. weight_decay: float,
  550. eps1: float,
  551. eps2: float,
  552. maximize: bool,
  553. ) -> None:
  554. r"""Functional API that performs Adafactor algorithm computation.
  555. See :class:`~torch.optim.Adafactor` for details.
  556. """
  557. if not torch.compiler.is_compiling() and not all(
  558. isinstance(t, torch.Tensor) for t in state_steps
  559. ):
  560. raise RuntimeError(
  561. "`state_steps` argument must contain a list of singleton tensors"
  562. )
  563. if foreach:
  564. func = _multi_tensor_adafactor
  565. else:
  566. func = _single_tensor_adafactor
  567. func(
  568. params,
  569. grads,
  570. row_vars,
  571. col_vars,
  572. variances,
  573. state_steps,
  574. d=d,
  575. lr=lr,
  576. beta2_decay=beta2_decay,
  577. weight_decay=weight_decay,
  578. eps1=eps1,
  579. eps2=eps2,
  580. maximize=maximize,
  581. grad_scale=grad_scale,
  582. found_inf=found_inf,
  583. has_complex=has_complex,
  584. )