adam.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991
  1. # mypy: allow-untyped-defs
  2. from typing import cast
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _device_dtype_check_for_fused,
  9. _differentiable_doc,
  10. _disable_dynamo_if_unsupported,
  11. _foreach_doc,
  12. _fused_doc,
  13. _get_capturable_supported_devices,
  14. _get_scalar_dtype,
  15. _get_value,
  16. _maximize_doc,
  17. _params_doc,
  18. _stack_if_compiling,
  19. _to_scalar,
  20. _use_grad_for_differentiable,
  21. _view_as_real,
  22. DeviceDict,
  23. DeviceDtypeDict,
  24. Optimizer,
  25. ParamsT,
  26. )
  27. __all__ = ["Adam", "adam"]
  28. class Adam(Optimizer):
  29. def __init__(
  30. self,
  31. params: ParamsT,
  32. lr: float | Tensor = 1e-3,
  33. betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999),
  34. eps: float = 1e-8,
  35. weight_decay: float = 0,
  36. amsgrad: bool = False,
  37. *,
  38. foreach: bool | None = None,
  39. maximize: bool = False,
  40. capturable: bool = False,
  41. differentiable: bool = False,
  42. fused: bool | None = None,
  43. decoupled_weight_decay: bool = False,
  44. ) -> None:
  45. if isinstance(lr, Tensor):
  46. if foreach and not capturable:
  47. raise ValueError(
  48. "lr as a Tensor is not supported for capturable=False and foreach=True"
  49. )
  50. if lr.numel() != 1:
  51. raise ValueError("Tensor lr must be 1-element")
  52. if not 0.0 <= lr:
  53. raise ValueError(f"Invalid learning rate: {lr}")
  54. if not 0.0 <= eps:
  55. raise ValueError(f"Invalid epsilon value: {eps}")
  56. if not 0.0 <= betas[0] < 1.0:
  57. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  58. if not 0.0 <= betas[1] < 1.0:
  59. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  60. if not 0.0 <= weight_decay:
  61. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  62. if not (
  63. (isinstance(betas[0], float) and isinstance(betas[1], float))
  64. or (isinstance(betas[0], Tensor) and isinstance(betas[1], Tensor))
  65. ):
  66. raise ValueError("betas must be either both floats or both Tensors")
  67. if isinstance(betas[0], Tensor):
  68. if not capturable and foreach:
  69. raise ValueError(
  70. "betas[0] as a Tensor is not supported for capturable=False and foreach=True"
  71. )
  72. if betas[0].numel() != 1:
  73. raise ValueError("Tensor betas[0] must be 1-element")
  74. if isinstance(betas[1], Tensor):
  75. if not capturable and foreach:
  76. raise ValueError(
  77. "betas[1] as a Tensor is not supported for capturable=False and foreach=True"
  78. )
  79. if betas[1].numel() != 1:
  80. raise ValueError("Tensor betas[1] must be 1-element")
  81. betas = tuple(map(_to_scalar, betas))
  82. defaults = {
  83. "lr": lr,
  84. "betas": betas,
  85. "eps": eps,
  86. "weight_decay": weight_decay,
  87. "amsgrad": amsgrad,
  88. "maximize": maximize,
  89. "foreach": foreach,
  90. "capturable": capturable,
  91. "differentiable": differentiable,
  92. "fused": fused,
  93. "decoupled_weight_decay": decoupled_weight_decay,
  94. }
  95. super().__init__(params, defaults)
  96. if fused:
  97. if differentiable:
  98. raise RuntimeError("`fused` does not support `differentiable`")
  99. self._step_supports_amp_scaling = True
  100. # TODO(crcrpar): [low prec params & their higher prec copy]
  101. # Support AMP with FP16/BF16 model params which would need
  102. # higher prec copy of params to do update math in higher prec to
  103. # alleviate the loss of information.
  104. if foreach:
  105. raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
  106. def __setstate__(self, state):
  107. super().__setstate__(state)
  108. for group in self.param_groups:
  109. group.setdefault("amsgrad", False)
  110. group.setdefault("maximize", False)
  111. group.setdefault("foreach", None)
  112. group.setdefault("capturable", False)
  113. group.setdefault("differentiable", False)
  114. group.setdefault("decoupled_weight_decay", False)
  115. fused = group.setdefault("fused", None)
  116. for p in group["params"]:
  117. p_state = self.state.get(p, [])
  118. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  119. step_val = float(p_state["step"])
  120. p_state["step"] = (
  121. torch.tensor(
  122. step_val,
  123. dtype=_get_scalar_dtype(is_fused=fused),
  124. device=p.device,
  125. )
  126. if group["capturable"] or group["fused"]
  127. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  128. )
  129. def _init_group(
  130. self,
  131. group,
  132. params_with_grad,
  133. grads,
  134. exp_avgs,
  135. exp_avg_sqs,
  136. max_exp_avg_sqs,
  137. state_steps,
  138. ):
  139. has_complex = False
  140. for p in group["params"]:
  141. if p.grad is not None:
  142. has_complex |= torch.is_complex(p)
  143. params_with_grad.append(p)
  144. if p.grad.is_sparse:
  145. raise RuntimeError(
  146. "Adam does not support sparse gradients, please consider SparseAdam instead"
  147. )
  148. grads.append(p.grad)
  149. state = self.state[p]
  150. # Lazy state initialization
  151. if len(state) == 0:
  152. if group["fused"]:
  153. _device_dtype_check_for_fused(p)
  154. # note(crcrpar): [special device hosting for step]
  155. # Deliberately host `step` on CPU if both capturable and fused are off.
  156. # This is because kernel launches are costly on CUDA and XLA.
  157. state["step"] = (
  158. torch.zeros(
  159. (),
  160. dtype=_get_scalar_dtype(is_fused=group["fused"]),
  161. device=p.device,
  162. )
  163. if group["capturable"] or group["fused"]
  164. else torch.tensor(0.0, dtype=_get_scalar_dtype())
  165. )
  166. # Exponential moving average of gradient values
  167. state["exp_avg"] = torch.zeros_like(
  168. p, memory_format=torch.preserve_format
  169. )
  170. # Exponential moving average of squared gradient values
  171. state["exp_avg_sq"] = torch.zeros_like(
  172. p, memory_format=torch.preserve_format
  173. )
  174. if group["amsgrad"]:
  175. # Maintains max of all exp. moving avg. of sq. grad. values
  176. state["max_exp_avg_sq"] = torch.zeros_like(
  177. p, memory_format=torch.preserve_format
  178. )
  179. exp_avgs.append(state["exp_avg"])
  180. exp_avg_sqs.append(state["exp_avg_sq"])
  181. if group["amsgrad"]:
  182. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  183. if group["differentiable"] and state["step"].requires_grad:
  184. raise RuntimeError(
  185. "`requires_grad` is not supported for `step` in differentiable mode"
  186. )
  187. # Foreach without capturable does not support a tensor lr
  188. if (
  189. group["foreach"]
  190. and torch.is_tensor(group["lr"])
  191. and not group["capturable"]
  192. ):
  193. raise RuntimeError(
  194. "lr as a Tensor is not supported for capturable=False and foreach=True"
  195. )
  196. state_steps.append(state["step"])
  197. return has_complex
  198. @_use_grad_for_differentiable
  199. def step(self, closure=None):
  200. """Perform a single optimization step.
  201. Args:
  202. closure (Callable, optional): A closure that reevaluates the model
  203. and returns the loss.
  204. """
  205. self._accelerator_graph_capture_health_check()
  206. loss = None
  207. if closure is not None:
  208. with torch.enable_grad():
  209. loss = closure()
  210. for group in self.param_groups:
  211. params_with_grad: list[Tensor] = []
  212. grads: list[Tensor] = []
  213. exp_avgs: list[Tensor] = []
  214. exp_avg_sqs: list[Tensor] = []
  215. max_exp_avg_sqs: list[Tensor] = []
  216. state_steps: list[Tensor] = []
  217. beta1, beta2 = group["betas"]
  218. has_complex = self._init_group(
  219. group,
  220. params_with_grad,
  221. grads,
  222. exp_avgs,
  223. exp_avg_sqs,
  224. max_exp_avg_sqs,
  225. state_steps,
  226. )
  227. adam(
  228. params_with_grad,
  229. grads,
  230. exp_avgs,
  231. exp_avg_sqs,
  232. max_exp_avg_sqs,
  233. state_steps,
  234. amsgrad=group["amsgrad"],
  235. has_complex=has_complex,
  236. beta1=beta1,
  237. beta2=beta2,
  238. lr=group["lr"],
  239. weight_decay=group["weight_decay"],
  240. eps=group["eps"],
  241. maximize=group["maximize"],
  242. foreach=group["foreach"],
  243. capturable=group["capturable"],
  244. differentiable=group["differentiable"],
  245. fused=group["fused"],
  246. grad_scale=getattr(self, "grad_scale", None),
  247. found_inf=getattr(self, "found_inf", None),
  248. decoupled_weight_decay=group["decoupled_weight_decay"],
  249. )
  250. return loss
  251. Adam.__doc__ = (
  252. r"""Implements Adam algorithm.
  253. .. math::
  254. \begin{aligned}
  255. &\rule{110mm}{0.4pt} \\
  256. &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
  257. \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
  258. &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
  259. \:\textit{maximize}, \: \epsilon \text{ (epsilon)} \\
  260. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  261. v_0\leftarrow 0 \text{ (second moment)},\: v_0^{max}\leftarrow 0 \\[-1.ex]
  262. &\rule{110mm}{0.4pt} \\
  263. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  264. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  265. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  266. &\hspace{5mm}\textbf{else} \\
  267. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  268. &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
  269. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  270. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  271. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  272. &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  273. &\hspace{5mm}\textbf{if} \: amsgrad \\
  274. &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\
  275. &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\
  276. &\hspace{5mm}\textbf{else} \\
  277. &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  278. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
  279. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  280. &\rule{110mm}{0.4pt} \\[-1.ex]
  281. &\bf{return} \: \theta_t \\[-1.ex]
  282. &\rule{110mm}{0.4pt} \\[-1.ex]
  283. \end{aligned}
  284. For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
  285. """
  286. + rf"""
  287. Args:
  288. {_params_doc}
  289. lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
  290. is not yet supported for all our implementations. Please use a float
  291. LR if you are not also specifying fused=True or capturable=True.
  292. betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
  293. coefficients used for computing running averages of gradient and
  294. its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
  295. eps (float, optional): term added to the denominator to improve
  296. numerical stability (default: 1e-8)
  297. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  298. decoupled_weight_decay (bool, optional): if True, this optimizer is
  299. equivalent to AdamW and the algorithm will not accumulate weight
  300. decay in the momentum nor variance. (default: False)
  301. amsgrad (bool, optional): whether to use the AMSGrad variant of this
  302. algorithm from the paper `On the Convergence of Adam and Beyond`_
  303. (default: False)
  304. {_foreach_doc}
  305. {_maximize_doc}
  306. {_capturable_doc}
  307. {_differentiable_doc}
  308. {_fused_doc}
  309. .. Note::
  310. A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
  311. .. _Adam\: A Method for Stochastic Optimization:
  312. https://arxiv.org/abs/1412.6980
  313. .. _On the Convergence of Adam and Beyond:
  314. https://openreview.net/forum?id=ryQu7f-RZ
  315. """
  316. )
  317. def _single_tensor_adam(
  318. params: list[Tensor],
  319. grads: list[Tensor],
  320. exp_avgs: list[Tensor],
  321. exp_avg_sqs: list[Tensor],
  322. max_exp_avg_sqs: list[Tensor],
  323. state_steps: list[Tensor],
  324. grad_scale: Tensor | None,
  325. found_inf: Tensor | None,
  326. *,
  327. amsgrad: bool,
  328. has_complex: bool,
  329. beta1: float | Tensor,
  330. beta2: float | Tensor,
  331. lr: float | Tensor,
  332. weight_decay: float,
  333. eps: float,
  334. maximize: bool,
  335. capturable: bool,
  336. differentiable: bool,
  337. decoupled_weight_decay: bool,
  338. ) -> None:
  339. if grad_scale is not None or found_inf is not None:
  340. raise AssertionError("Expected grad_scale and found_inf to be None")
  341. if torch.jit.is_scripting():
  342. # this assert is due to JIT being dumb and not realizing that the ops below
  343. # have overloads to handle both float and Tensor lrs, so we just assert it's
  344. # a float since most people using JIT are using floats
  345. if not isinstance(lr, float):
  346. raise AssertionError(f"Expected lr to be a float, but got {type(lr)}")
  347. if not isinstance(beta1, float):
  348. raise AssertionError(f"Expected beta1 to be a float, but got {type(beta1)}")
  349. if not isinstance(beta2, float):
  350. raise AssertionError(f"Expected beta2 to be a float, but got {type(beta2)}")
  351. else:
  352. lr = _to_scalar(lr)
  353. beta1 = _to_scalar(beta1)
  354. beta2 = _to_scalar(beta2)
  355. # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
  356. # treating it as a scalar.
  357. # Note: ensure type declaration is under conditional check for isinstance
  358. # or else torchscript will get cranky about the DeviceDict type.
  359. if isinstance(beta1, Tensor):
  360. beta1_dict: DeviceDtypeDict | None = {(beta1.device, beta1.dtype): beta1}
  361. else:
  362. beta1_dict = None
  363. for i, param in enumerate(params):
  364. grad = grads[i] if not maximize else -grads[i]
  365. exp_avg = exp_avgs[i]
  366. exp_avg_sq = exp_avg_sqs[i]
  367. step_t = state_steps[i]
  368. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  369. if not torch.compiler.is_compiling() and capturable:
  370. capturable_supported_devices = _get_capturable_supported_devices()
  371. if not (
  372. param.device.type == step_t.device.type
  373. and param.device.type in capturable_supported_devices
  374. ):
  375. raise AssertionError(
  376. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  377. )
  378. # update step
  379. step_t += 1
  380. if weight_decay != 0:
  381. if decoupled_weight_decay:
  382. # Perform stepweight decay
  383. param.mul_(1 - lr * weight_decay)
  384. else:
  385. # Nested if is necessary to bypass jitscript rules
  386. if differentiable and isinstance(weight_decay, Tensor):
  387. if weight_decay.requires_grad:
  388. grad = grad.addcmul_(param.clone(), weight_decay)
  389. else:
  390. # pyrefly: ignore [bad-argument-type]
  391. grad = grad.add(param, alpha=weight_decay)
  392. else:
  393. grad = grad.add(param, alpha=weight_decay)
  394. if torch.is_complex(param):
  395. grad = torch.view_as_real(grad)
  396. exp_avg = torch.view_as_real(exp_avg)
  397. exp_avg_sq = torch.view_as_real(exp_avg_sq)
  398. if amsgrad:
  399. max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
  400. param = torch.view_as_real(param)
  401. device = param.device
  402. if beta1_dict is not None:
  403. dtype = param.dtype # type: ignore[union-attr]
  404. # cast to workaround https://github.com/pytorch/pytorch/issues/140601
  405. key = (device, dtype)
  406. if key not in beta1_dict:
  407. beta1_dict[key] = beta1.to( # type: ignore[union-attr]
  408. device=device, dtype=dtype, non_blocking=True
  409. )
  410. device_beta1: float | Tensor = beta1_dict[key]
  411. else:
  412. device_beta1 = beta1
  413. # Decay the first and second moment running average coefficient
  414. exp_avg.lerp_(grad, 1 - device_beta1)
  415. # Nested if is necessary to bypass jitscript rules
  416. if differentiable and isinstance(beta2, Tensor):
  417. if beta2.requires_grad:
  418. # Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
  419. # Showing equivalence of differentiable path and nondifferentiable path
  420. # expavg * b2 + grad^2 * (1-b2)
  421. # add expavg * (1-b2) - expavg * (1-b2) = 0
  422. # expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
  423. # expavg - expavg * (1-b2) + grad^2 * (1-b2)
  424. # expavg + (grad^2 - expavg) * (1-b2)
  425. # expavg.lerp(grad^2, 1-beta2)
  426. exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
  427. else:
  428. exp_avg_sq.mul_(beta2).addcmul_(
  429. grad, grad, value=cast(float, 1 - beta2)
  430. )
  431. else:
  432. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type]
  433. if capturable or differentiable:
  434. step = step_t
  435. # Nested if is necessary to bypass jitscript rules
  436. if differentiable and isinstance(beta1, Tensor):
  437. if beta1.requires_grad:
  438. bias_correction1 = 1 - beta1 ** step.clone()
  439. else:
  440. bias_correction1 = 1 - beta1**step
  441. else:
  442. bias_correction1 = 1 - beta1**step
  443. # Nested if is necessary to bypass jitscript rules
  444. if differentiable and isinstance(beta2, Tensor):
  445. if beta2.requires_grad:
  446. bias_correction2 = 1 - beta2 ** step.clone()
  447. else:
  448. bias_correction2 = 1 - beta2**step
  449. else:
  450. bias_correction2 = 1 - beta2**step
  451. step_size = lr / bias_correction1
  452. step_size_neg = step_size.neg()
  453. bias_correction2_sqrt = bias_correction2.sqrt()
  454. if amsgrad:
  455. # Maintains the maximum of all 2nd moment running avg. till now
  456. if differentiable:
  457. max_exp_avg_sq = max_exp_avg_sqs[i].clone()
  458. else:
  459. max_exp_avg_sq = max_exp_avg_sqs[i]
  460. max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
  461. # Uses the max. for normalizing running avg. of gradient
  462. # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
  463. # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
  464. denom = (
  465. max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
  466. ).add_(eps / step_size_neg)
  467. else:
  468. denom = (
  469. exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
  470. ).add_(eps / step_size_neg)
  471. if differentiable:
  472. param.addcdiv_(exp_avg.clone(), denom)
  473. else:
  474. param.addcdiv_(exp_avg, denom)
  475. else:
  476. step = _get_value(step_t)
  477. bias_correction1 = 1 - beta1**step
  478. bias_correction2 = 1 - beta2**step
  479. step_size = lr / bias_correction1
  480. bias_correction2_sqrt = bias_correction2**0.5
  481. if amsgrad:
  482. # Maintains the maximum of all 2nd moment running avg. till now
  483. torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
  484. # Use the max. for normalizing running avg. of gradient
  485. denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
  486. else:
  487. denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
  488. param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type]
  489. # Lastly, switch back to complex view
  490. if amsgrad and torch.is_complex(params[i]):
  491. max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
  492. def _multi_tensor_adam(
  493. params: list[Tensor],
  494. grads: list[Tensor],
  495. exp_avgs: list[Tensor],
  496. exp_avg_sqs: list[Tensor],
  497. max_exp_avg_sqs: list[Tensor],
  498. state_steps: list[Tensor],
  499. grad_scale: Tensor | None,
  500. found_inf: Tensor | None,
  501. *,
  502. amsgrad: bool,
  503. has_complex: bool,
  504. beta1: float | Tensor,
  505. beta2: float | Tensor,
  506. lr: float | Tensor,
  507. weight_decay: float,
  508. eps: float,
  509. maximize: bool,
  510. capturable: bool,
  511. differentiable: bool,
  512. decoupled_weight_decay: bool,
  513. ) -> None:
  514. if len(params) == 0:
  515. return
  516. if isinstance(lr, Tensor):
  517. if not capturable:
  518. raise RuntimeError(
  519. "lr as a Tensor is not supported for capturable=False and foreach=True"
  520. )
  521. if lr.numel() != 1:
  522. raise ValueError("Tensor lr must be 1-element")
  523. if isinstance(beta1, Tensor):
  524. if not capturable:
  525. raise ValueError(
  526. "beta1 as a Tensor is not supported for capturable=False and foreach=True"
  527. )
  528. if beta1.numel() != 1:
  529. raise ValueError("Tensor beta1 must be 1-element")
  530. if isinstance(beta2, Tensor):
  531. if not capturable:
  532. raise ValueError(
  533. "beta2 as a Tensor is not supported for capturable=False and foreach=True"
  534. )
  535. if beta2.numel() != 1:
  536. raise ValueError("Tensor beta2 must be 1-element")
  537. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  538. if not torch.compiler.is_compiling() and capturable:
  539. capturable_supported_devices = _get_capturable_supported_devices(
  540. supports_xla=False
  541. )
  542. if not all(
  543. p.device.type == step.device.type
  544. and p.device.type in capturable_supported_devices
  545. for p, step in zip(params, state_steps, strict=True)
  546. ):
  547. raise AssertionError(
  548. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  549. )
  550. if grad_scale is not None or found_inf is not None:
  551. raise AssertionError("Expected grad_scale and found_inf to be None")
  552. if differentiable:
  553. raise AssertionError("_foreach ops don't support autograd")
  554. lr = _to_scalar(lr)
  555. beta1 = _to_scalar(beta1)
  556. beta2 = _to_scalar(beta2)
  557. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  558. [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
  559. )
  560. # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer
  561. # treating it as a scalar.
  562. beta1_dict: DeviceDict | None = ( # type: ignore[attr-defined]
  563. {beta1.device: beta1}
  564. if isinstance(beta1, Tensor) and str(beta1.device) != "cpu"
  565. else None
  566. )
  567. for (
  568. device_params_,
  569. device_grads_,
  570. device_exp_avgs_,
  571. device_exp_avg_sqs_,
  572. device_max_exp_avg_sqs_,
  573. device_state_steps_,
  574. ), _ in grouped_tensors.values():
  575. device_params = cast(list[Tensor], device_params_)
  576. device_grads = cast(list[Tensor], device_grads_)
  577. device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
  578. device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
  579. device_state_steps = cast(list[Tensor], device_state_steps_)
  580. device = device_params[0].device
  581. if beta1_dict is not None and device not in beta1_dict:
  582. beta1_dict[device] = beta1.to(device=device, non_blocking=True) # type: ignore[union-attr, attr-defined]
  583. device_beta1 = beta1_dict[device] if beta1_dict else beta1
  584. # Handle complex parameters
  585. if has_complex:
  586. if amsgrad:
  587. device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
  588. _view_as_real(
  589. device_params,
  590. device_grads,
  591. device_exp_avgs,
  592. device_exp_avg_sqs,
  593. device_max_exp_avg_sqs,
  594. )
  595. else:
  596. _view_as_real(
  597. device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
  598. )
  599. if maximize:
  600. device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
  601. # Update steps
  602. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  603. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  604. # wrapped it once now. The alpha is required to assure we go to the right overload.
  605. if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
  606. torch._foreach_add_(
  607. device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  608. )
  609. else:
  610. torch._foreach_add_(device_state_steps, 1)
  611. if weight_decay != 0:
  612. if decoupled_weight_decay:
  613. # Perform stepweight decay
  614. torch._foreach_mul_(device_params, 1 - lr * weight_decay)
  615. else:
  616. # Reuse the intermediate memory (device_grads) already allocated for maximize
  617. if maximize:
  618. torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
  619. else:
  620. device_grads = torch._foreach_add( # type: ignore[assignment]
  621. device_grads, device_params, alpha=weight_decay
  622. )
  623. # Decay the first and second moment running average coefficient
  624. # Use device beta1 if beta1 is a tensor to ensure all
  625. # tensors are on the same device
  626. torch._foreach_lerp_(
  627. device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
  628. )
  629. torch._foreach_mul_(device_exp_avg_sqs, beta2)
  630. # Due to the strictness of the _foreach_addcmul API, we can't have a single
  631. # tensor scalar as the scalar arg (only python number is supported there)
  632. # as a result, separate out the value mul
  633. # Filed https://github.com/pytorch/pytorch/issues/139795
  634. if isinstance(beta2, torch.Tensor):
  635. scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2) # type: ignore[assignment]
  636. value = 1.0
  637. else:
  638. scaled_device_grads = device_grads # type: ignore[assignment]
  639. value = 1 - beta2
  640. torch._foreach_addcmul_(
  641. device_exp_avg_sqs, scaled_device_grads, device_grads, value
  642. )
  643. # Delete the local intermediate(s) since they won't be used anymore to save on peak memory
  644. del device_grads
  645. del scaled_device_grads
  646. bias_correction1: tuple[Tensor, ...] | list[Tensor]
  647. bias_correction2: tuple[Tensor, ...] | list[Tensor]
  648. bias_correction2_sqrt: tuple[Tensor, ...] | list[Tensor]
  649. if capturable:
  650. bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type]
  651. bias_correction2 = torch._foreach_pow(beta2, device_state_steps) # type: ignore[arg-type]
  652. # foreach_sub doesn't allow a scalar as the first arg
  653. torch._foreach_sub_(bias_correction1, 1)
  654. torch._foreach_sub_(bias_correction2, 1)
  655. # we do not negate bias_correction1 as it'll need to be negated later anyway
  656. torch._foreach_neg_(bias_correction2)
  657. # foreach_div doesn't allow a scalar as the first arg
  658. torch._foreach_div_(bias_correction1, lr)
  659. torch._foreach_reciprocal_(bias_correction1)
  660. torch._foreach_sqrt_(bias_correction2)
  661. # Re-assign for clarity as we maintain minimal intermediates: we'll have
  662. # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
  663. # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
  664. step_size = bias_correction1
  665. bias_correction2_sqrt = bias_correction2
  666. if amsgrad:
  667. device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
  668. # Maintains the maximum of all 2nd moment running avg. till now
  669. torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment]
  670. # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
  671. exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
  672. else:
  673. exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
  674. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
  675. torch._foreach_add_(exp_avg_sq_sqrt, eps)
  676. torch._foreach_div_(exp_avg_sq_sqrt, step_size)
  677. # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
  678. torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
  679. else:
  680. bias_correction1 = [
  681. 1 - beta1 ** _get_value(step) for step in device_state_steps
  682. ]
  683. bias_correction2 = [
  684. 1 - beta2 ** _get_value(step) for step in device_state_steps
  685. ]
  686. step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
  687. bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type]
  688. if amsgrad:
  689. device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
  690. # Maintains the maximum of all 2nd moment running avg. till now
  691. torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
  692. # Use the max. for normalizing running avg. of gradient
  693. exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
  694. else:
  695. exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
  696. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
  697. torch._foreach_add_(exp_avg_sq_sqrt, eps)
  698. torch._foreach_addcdiv_(
  699. device_params,
  700. device_exp_avgs,
  701. exp_avg_sq_sqrt,
  702. step_size, # type: ignore[arg-type]
  703. )
  704. def _fused_adam(
  705. params: list[Tensor],
  706. grads: list[Tensor],
  707. exp_avgs: list[Tensor],
  708. exp_avg_sqs: list[Tensor],
  709. max_exp_avg_sqs: list[Tensor],
  710. state_steps: list[Tensor],
  711. grad_scale: Tensor | None,
  712. found_inf: Tensor | None,
  713. *,
  714. amsgrad: bool,
  715. has_complex: bool, # Needed for consistency.
  716. beta1: float | Tensor,
  717. beta2: float | Tensor,
  718. lr: float | Tensor,
  719. weight_decay: float,
  720. eps: float,
  721. maximize: bool,
  722. capturable: bool, # Needed for consistency.
  723. differentiable: bool,
  724. decoupled_weight_decay: bool,
  725. ) -> None:
  726. if not params:
  727. return
  728. if differentiable:
  729. raise RuntimeError("Adam with fused=True does not support differentiable=True")
  730. beta1 = _to_scalar(beta1)
  731. beta2 = _to_scalar(beta2)
  732. grad_scale_dict: DeviceDict = (
  733. {grad_scale.device: grad_scale} if grad_scale is not None else {}
  734. )
  735. found_inf_dict: DeviceDict = (
  736. {found_inf.device: found_inf} if found_inf is not None else {}
  737. )
  738. # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
  739. # treating it as a scalar.
  740. lr_dict: DeviceDict | None = (
  741. {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
  742. )
  743. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  744. [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
  745. )
  746. for (device, _), (
  747. (
  748. device_params_,
  749. device_grads_,
  750. device_exp_avgs_,
  751. device_exp_avg_sqs_,
  752. device_max_exp_avg_sqs,
  753. device_state_steps_,
  754. ),
  755. _,
  756. ) in grouped_tensors.items():
  757. device_params = cast(list[Tensor], device_params_)
  758. device_grads = cast(list[Tensor], device_grads_)
  759. device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
  760. device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
  761. device_state_steps = cast(list[Tensor], device_state_steps_)
  762. device_grad_scale, device_found_inf = None, None
  763. if grad_scale is not None:
  764. device_grad_scale = grad_scale_dict.setdefault(
  765. device, grad_scale.to(device, non_blocking=True)
  766. )
  767. if found_inf is not None:
  768. device_found_inf = found_inf_dict.setdefault(
  769. device, found_inf.to(device, non_blocking=True)
  770. )
  771. if lr_dict is not None and device not in lr_dict:
  772. lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
  773. lr = lr_dict[device]
  774. torch._foreach_add_(device_state_steps, 1)
  775. func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_
  776. # pyrefly: ignore [no-matching-overload]
  777. func(
  778. device_params,
  779. device_grads,
  780. device_exp_avgs,
  781. device_exp_avg_sqs,
  782. device_max_exp_avg_sqs, # type: ignore[arg-type]
  783. device_state_steps,
  784. amsgrad=amsgrad,
  785. lr=lr, # type: ignore[arg-type]
  786. beta1=beta1,
  787. beta2=beta2,
  788. weight_decay=weight_decay,
  789. eps=eps,
  790. maximize=maximize,
  791. grad_scale=device_grad_scale,
  792. found_inf=device_found_inf,
  793. )
  794. if device_found_inf is not None:
  795. torch._foreach_sub_(
  796. device_state_steps, [device_found_inf] * len(device_state_steps)
  797. )
# NOTE(review): the decorator receives _single_tensor_adam as a fallback;
# presumably used when dynamo cannot handle the selected path. Confirm in its definition.
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: bool | None = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: bool | None = None,
    grad_scale: Tensor | None = None,
    found_inf: Tensor | None = None,
    has_complex: bool = False,
    decoupled_weight_decay: bool = False,
    *,
    amsgrad: bool,
    beta1: float | Tensor,
    beta2: float | Tensor,
    lr: float | Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
) -> None:
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    This entry point only chooses an implementation and forwards every
    argument to it unchanged:

    * ``fused`` truthy   -> ``_fused_adam`` (takes precedence over ``foreach``),
    * ``foreach`` truthy -> ``_multi_tensor_adam``,
    * otherwise          -> ``_single_tensor_adam``.

    When the caller specifies neither ``fused`` nor ``foreach``, ``foreach`` is
    picked by ``_default_to_fused_or_foreach(params, differentiable,
    use_fused=False)``. It is then forced back to False if ``lr`` is a Tensor
    and ``capturable`` is False. An explicit True/False from the caller is
    never overridden. A remaining ``None`` for either flag becomes False.

    Raises:
        RuntimeError: If, outside of ``torch.compile``, any element of
            ``state_steps`` is not a Tensor.
        RuntimeError: If ``foreach`` or ``fused`` is requested while running
            under TorchScript.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    # Normalize any flag still left as None to an explicit False.
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # The `not torch.jit.is_scripting()` guards are redundant at runtime given the
    # raises above. NOTE(review): they are presumably kept so TorchScript does not
    # compile the fused/foreach branches. Confirm before simplifying.
    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        decoupled_weight_decay=decoupled_weight_decay,
    )