| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991 |
- # mypy: allow-untyped-defs
- from typing import cast
- import torch
- from torch import Tensor
- from .optimizer import (
- _capturable_doc,
- _default_to_fused_or_foreach,
- _device_dtype_check_for_fused,
- _differentiable_doc,
- _disable_dynamo_if_unsupported,
- _foreach_doc,
- _fused_doc,
- _get_capturable_supported_devices,
- _get_scalar_dtype,
- _get_value,
- _maximize_doc,
- _params_doc,
- _stack_if_compiling,
- _to_scalar,
- _use_grad_for_differentiable,
- _view_as_real,
- DeviceDict,
- DeviceDtypeDict,
- Optimizer,
- ParamsT,
- )
# Public names of this module: the ``Adam`` optimizer class and the
# functional ``adam`` entry point.
__all__ = ["Adam", "adam"]
class Adam(Optimizer):
    # User-facing documentation is assigned to ``Adam.__doc__`` after the class body.

    def __init__(
        self,
        params: ParamsT,
        lr: float | Tensor = 1e-3,
        betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        *,
        foreach: bool | None = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: bool | None = None,
        decoupled_weight_decay: bool = False,
    ) -> None:
        """Validate the hyperparameters and register the parameter groups.

        Raises:
            ValueError: if a hyperparameter is out of range, the betas mix
                floats and Tensors, a Tensor hyperparameter is not 1-element,
                or a Tensor ``lr``/beta is combined with ``foreach=True`` and
                ``capturable=False``.
            RuntimeError: if ``fused`` is combined with ``differentiable`` or
                with ``foreach``.
        """
        # Tensor-valued lr / betas are rejected for an explicit foreach request
        # unless capturable is enabled.
        tensor_hparam_forbidden = bool(foreach) and not capturable

        if isinstance(lr, Tensor):
            if tensor_hparam_forbidden:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")

        # Range checks are written as ``not 0.0 <= x`` so NaN is rejected too.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        first_beta, second_beta = betas[0], betas[1]
        floats_pair = isinstance(first_beta, float) and isinstance(second_beta, float)
        tensors_pair = isinstance(first_beta, Tensor) and isinstance(
            second_beta, Tensor
        )
        if not (floats_pair or tensors_pair):
            raise ValueError("betas must be either both floats or both Tensors")

        if isinstance(first_beta, Tensor):
            if tensor_hparam_forbidden:
                raise ValueError(
                    "betas[0] as a Tensor is not supported for capturable=False and foreach=True"
                )
            if first_beta.numel() != 1:
                raise ValueError("Tensor betas[0] must be 1-element")
        if isinstance(second_beta, Tensor):
            if tensor_hparam_forbidden:
                raise ValueError(
                    "betas[1] as a Tensor is not supported for capturable=False and foreach=True"
                )
            if second_beta.numel() != 1:
                raise ValueError("Tensor betas[1] must be 1-element")

        # Normalize each beta through the shared ``_to_scalar`` helper.
        betas = tuple(_to_scalar(beta) for beta in betas)

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
            decoupled_weight_decay=decoupled_weight_decay,
        )
        super().__init__(params, defaults)

        if not fused:
            return
        if differentiable:
            raise RuntimeError("`fused` does not support `differentiable`")
        self._step_supports_amp_scaling = True
        # TODO(crcrpar): [low prec params & their higher prec copy]
        # Support AMP with FP16/BF16 model params which would need
        # higher prec copy of params to do update math in higher prec to
        # alleviate the loss of information.
        if foreach:
            raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        """Restore pickled optimizer state and back-fill missing group options.

        Any per-parameter ``step`` that is not already a tensor is converted to
        a scalar tensor. It is placed on the parameter's device when the group
        is capturable or fused, and kept on CPU otherwise.
        """
        super().__setstate__(state)
        backfilled_options = (
            ("amsgrad", False),
            ("maximize", False),
            ("foreach", None),
            ("capturable", False),
            ("differentiable", False),
            ("decoupled_weight_decay", False),
        )
        for group in self.param_groups:
            for option, fallback in backfilled_options:
                group.setdefault(option, fallback)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, {})
                if not p_state or torch.is_tensor(p_state["step"]):
                    continue
                step_val = float(p_state["step"])
                if group["capturable"] or group["fused"]:
                    p_state["step"] = torch.tensor(
                        step_val,
                        dtype=_get_scalar_dtype(is_fused=fused),
                        device=p.device,
                    )
                else:
                    p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        """Gather the tensors of ``group`` that have a gradient.

        Per-parameter state is created lazily on first use. The output lists are
        appended to in ``group["params"]`` order.

        Returns:
            bool: whether any parameter with a gradient is complex.

        Raises:
            RuntimeError: for sparse gradients, for a ``step`` that requires
                grad in differentiable mode, or for a Tensor lr with
                ``foreach`` and without ``capturable``.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue

            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError(
                    "Adam does not support sparse gradients, please consider SparseAdam instead"
                )
            grads.append(p.grad)

            state = self.state[p]
            if len(state) == 0:
                fused = group["fused"]
                if fused:
                    _device_dtype_check_for_fused(p)
                # note(crcrpar): [special device hosting for step]
                # Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                if group["capturable"] or fused:
                    state["step"] = torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=fused),
                        device=p.device,
                    )
                else:
                    state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
                # Running averages of the gradient and of its square, plus the
                # running maximum of the latter when AMSGrad is enabled.
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if group["amsgrad"]:
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            if group["amsgrad"]:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            if group["differentiable"] and state["step"].requires_grad:
                raise RuntimeError(
                    "`requires_grad` is not supported for `step` in differentiable mode"
                )
            # The foreach path cannot take a Tensor lr unless capturable is on.
            tensor_lr_with_foreach = (
                group["foreach"]
                and torch.is_tensor(group["lr"])
                and not group["capturable"]
            )
            if tensor_lr_with_foreach:
                raise RuntimeError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )

            state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Run one Adam update over every parameter group.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        self._accelerator_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: list[Tensor] = []
            grads: list[Tensor] = []
            exp_avgs: list[Tensor] = []
            exp_avg_sqs: list[Tensor] = []
            max_exp_avg_sqs: list[Tensor] = []
            state_steps: list[Tensor] = []
            beta1, beta2 = group["betas"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=group["amsgrad"],
                has_complex=has_complex,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                decoupled_weight_decay=group["decoupled_weight_decay"],
            )

        return loss
# The class docstring is assembled here, after the class body, so that the
# shared argument descriptions imported from ``.optimizer`` can be spliced in
# through the f-string below.
Adam.__doc__ = (
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}          \\
            &\hspace{13mm}      \lambda \text{ (weight decay)},  \: \textit{amsgrad},
                \:\textit{maximize},  \: \epsilon \text{ (epsilon)},
                \: \textit{decoupled\_weight\_decay}                                             \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ (first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: v_0^{max}\leftarrow 0          \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1})                     \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                      \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay}                       \\
            &\hspace{15mm}\theta_{t-1} \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}     \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm}g_t \leftarrow g_t + \lambda \theta_{t-1}                              \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t                      \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t                      \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big)                     \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t)                  \\
            &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big)              \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big)                    \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        {_params_doc}
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
            coefficients used for computing running averages of gradient and
            its square. Both entries must be floats or both must be Tensors, and
            a Tensor entry must be 1-element. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_weight_decay (bool, optional): if True, this optimizer is
            equivalent to AdamW and the algorithm will not accumulate weight
            decay in the momentum nor variance. (default: False)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}

    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)
def _single_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    grad_scale: Tensor | None,
    found_inf: Tensor | None,
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float | Tensor,
    beta2: float | Tensor,
    lr: float | Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    decoupled_weight_decay: bool,
) -> None:
    """Per-parameter (for-loop) implementation of Adam / AdamW.

    Processes one parameter at a time. It updates ``params``, ``exp_avgs``,
    ``exp_avg_sqs`` and ``state_steps`` in place, plus ``max_exp_avg_sqs``
    when ``amsgrad`` is set. The foreach and fused variants reject
    ``differentiable=True``, so this is the path used for differentiable
    optimizers and under TorchScript. Because of TorchScript, the nested ``if``
    blocks and the placement of type declarations below must be preserved.

    There are two step-size code paths:

    * ``capturable or differentiable``: the bias corrections are computed
      directly from the ``step`` tensor, without reading it back via
      ``_get_value``.
    * Otherwise: ``step`` is read out with ``_get_value`` and the scalar math
      is done on that value.

    Raises:
        AssertionError: if ``grad_scale``/``found_inf`` is given. Also raised
            when, under TorchScript, ``lr``/``beta1``/``beta2`` is not a float,
            or when, with ``capturable=True``, a param/step pair is on an
            unsupported or mismatched device.
    """
    if grad_scale is not None or found_inf is not None:
        raise AssertionError("Expected grad_scale and found_inf to be None")
    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        if not isinstance(lr, float):
            raise AssertionError(f"Expected lr to be a float, but got {type(lr)}")
        if not isinstance(beta1, float):
            raise AssertionError(f"Expected beta1 to be a float, but got {type(beta1)}")
        if not isinstance(beta2, float):
            raise AssertionError(f"Expected beta2 to be a float, but got {type(beta2)}")
    else:
        lr = _to_scalar(lr)
        beta1 = _to_scalar(beta1)
        beta2 = _to_scalar(beta2)
    # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
    # treating it as a scalar.
    # Note: ensure type declaration is under conditional check for isinstance
    # or else torchscript will get cranky about the DeviceDict type.
    # beta1_dict caches one copy of a Tensor beta1 per (device, dtype) of the params.
    if isinstance(beta1, Tensor):
        beta1_dict: DeviceDtypeDict | None = {(beta1.device, beta1.dtype): beta1}
    else:
        beta1_dict = None
    for i, param in enumerate(params):
        # Negating creates a new tensor, so the caller's grads[i] is not touched here.
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]
        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch.compiler.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            if not (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ):
                raise AssertionError(
                    f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
                )
        # update step
        step_t += 1
        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                # (AdamW style: shrink the parameter itself; the moments never see lambda).
                param.mul_(1 - lr * weight_decay)
            else:
                # Classic L2 penalty: fold weight_decay * param into the gradient.
                # Nested if is necessary to bypass jitscript rules
                if differentiable and isinstance(weight_decay, Tensor):
                    if weight_decay.requires_grad:
                        # NOTE(review): addcmul_ is in-place on `grad`, which is
                        # grads[i] itself when maximize is False; confirm this is intended.
                        grad = grad.addcmul_(param.clone(), weight_decay)
                    else:
                        # pyrefly: ignore [bad-argument-type]
                        grad = grad.add(param, alpha=weight_decay)
                else:
                    grad = grad.add(param, alpha=weight_decay)
        # Complex params are updated through real views. The max_exp_avg_sqs
        # list entry is swapped for its real view and restored at the end of
        # the iteration.
        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)
        device = param.device
        if beta1_dict is not None:
            dtype = param.dtype  # type: ignore[union-attr]
            # cast to workaround https://github.com/pytorch/pytorch/issues/140601
            key = (device, dtype)
            if key not in beta1_dict:
                beta1_dict[key] = beta1.to(  # type: ignore[union-attr]
                    device=device, dtype=dtype, non_blocking=True
                )
            device_beta1: float | Tensor = beta1_dict[key]
        else:
            device_beta1 = beta1
        # Decay the first and second moment running average coefficient
        # m_t = m_{t-1} + (1 - beta1) * (g_t - m_{t-1}) == beta1 * m_{t-1} + (1 - beta1) * g_t
        exp_avg.lerp_(grad, 1 - device_beta1)
        # Nested if is necessary to bypass jitscript rules
        if differentiable and isinstance(beta2, Tensor):
            if beta2.requires_grad:
                # Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
                # Showing equivalence of differentiable path and nondifferentiable path
                # expavg * b2 + grad^2 * (1-b2)
                # add expavg * (1-b2) - expavg * (1-b2) = 0
                # expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
                # expavg - expavg * (1-b2) + grad^2 * (1-b2)
                # expavg + (grad^2 - expavg) * (1-b2)
                # expavg.lerp(grad^2, 1-beta2)
                exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
            else:
                exp_avg_sq.mul_(beta2).addcmul_(
                    grad, grad, value=cast(float, 1 - beta2)
                )
        else:
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # type: ignore[arg-type]
        if capturable or differentiable:
            # Bias corrections stay tensors derived from step_t (no _get_value read-back).
            step = step_t
            # Nested if is necessary to bypass jitscript rules
            # (step.clone() is used when the beta requires grad.)
            if differentiable and isinstance(beta1, Tensor):
                if beta1.requires_grad:
                    bias_correction1 = 1 - beta1 ** step.clone()
                else:
                    bias_correction1 = 1 - beta1**step
            else:
                bias_correction1 = 1 - beta1**step
            # Nested if is necessary to bypass jitscript rules
            if differentiable and isinstance(beta2, Tensor):
                if beta2.requires_grad:
                    bias_correction2 = 1 - beta2 ** step.clone()
                else:
                    bias_correction2 = 1 - beta2**step
            else:
                bias_correction2 = 1 - beta2**step
            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()
            bias_correction2_sqrt = bias_correction2.sqrt()
            # In both branches below:
            #   denom = (sqrt(v) / bc2_sqrt + eps) / (-step_size)
            # so the default-valued addcdiv_ yields
            #   param -= step_size * m / (sqrt(v) / bc2_sqrt + eps).
            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]
                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            if differentiable:
                param.addcdiv_(exp_avg.clone(), denom)
            else:
                param.addcdiv_(exp_avg, denom)
        else:
            # Scalar path: read the step count out of step_t and do the bias math on it.
            step = _get_value(step_t)
            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step
            step_size = lr / bias_correction1
            bias_correction2_sqrt = bias_correction2**0.5
            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
            param.addcdiv_(exp_avg, denom, value=-step_size)  # type: ignore[arg-type]
        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
def _multi_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    grad_scale: Tensor | None,
    found_inf: Tensor | None,
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float | Tensor,
    beta2: float | Tensor,
    lr: float | Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    decoupled_weight_decay: bool,
) -> None:
    """Foreach (multi-tensor) implementation of Adam / AdamW.

    Inputs are bucketed by ``Optimizer._group_tensors_by_device_and_dtype``.
    Each bucket is then updated with ``torch._foreach_*`` ops, which modify
    ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps`` in place,
    plus ``max_exp_avg_sqs`` when ``amsgrad`` is set. Tensor ``lr``/``beta1``/
    ``beta2`` are only accepted together with ``capturable=True``.

    Raises:
        RuntimeError: if ``lr`` is a Tensor and ``capturable`` is False.
        ValueError: if a Tensor ``beta1``/``beta2`` is given with
            ``capturable=False``, or a Tensor hyperparameter is not 1-element.
        AssertionError: if ``capturable`` and a param/step pair is on an
            unsupported or mismatched device, if ``grad_scale``/``found_inf``
            is given, or if ``differentiable`` is True.
    """
    if len(params) == 0:
        return
    if isinstance(lr, Tensor):
        if not capturable:
            raise RuntimeError(
                "lr as a Tensor is not supported for capturable=False and foreach=True"
            )
        if lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
    if isinstance(beta1, Tensor):
        if not capturable:
            raise ValueError(
                "beta1 as a Tensor is not supported for capturable=False and foreach=True"
            )
        if beta1.numel() != 1:
            raise ValueError("Tensor beta1 must be 1-element")
    if isinstance(beta2, Tensor):
        if not capturable:
            raise ValueError(
                "beta2 as a Tensor is not supported for capturable=False and foreach=True"
            )
        if beta2.numel() != 1:
            raise ValueError("Tensor beta2 must be 1-element")
    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch.compiler.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        if not all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps, strict=True)
        ):
            raise AssertionError(
                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
            )
    if grad_scale is not None or found_inf is not None:
        raise AssertionError("Expected grad_scale and found_inf to be None")
    if differentiable:
        raise AssertionError("_foreach ops don't support autograd")
    lr = _to_scalar(lr)
    beta1 = _to_scalar(beta1)
    beta2 = _to_scalar(beta2)
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    # (The actual test is "not on CPU"; copies are cached per device below.)
    beta1_dict: DeviceDict | None = (  # type: ignore[attr-defined]
        {beta1.device: beta1}
        if isinstance(beta1, Tensor) and str(beta1.device) != "cpu"
        else None
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(list[Tensor], device_params_)
        device_grads = cast(list[Tensor], device_grads_)
        device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(list[Tensor], device_state_steps_)
        device = device_params[0].device
        if beta1_dict is not None and device not in beta1_dict:
            beta1_dict[device] = beta1.to(device=device, non_blocking=True)  # type: ignore[union-attr, attr-defined]
        device_beta1 = beta1_dict[device] if beta1_dict else beta1
        # Handle complex parameters
        # (_view_as_real comes from .optimizer; presumably it swaps the list
        # entries for real views in place. Confirm against its definition.)
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )
        # Out-of-place negation: device_grads now refers to a fresh intermediate,
        # which the L2 weight-decay branch below reuses in place.
        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)
        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                torch._foreach_mul_(device_params, 1 - lr * weight_decay)
            else:
                # Reuse the intermediate memory (device_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
                else:
                    # Out-of-place so the caller's gradient tensors are left untouched.
                    device_grads = torch._foreach_add(  # type: ignore[assignment]
                        device_grads, device_params, alpha=weight_decay
                    )
        # Decay the first and second moment running average coefficient
        # Use device beta1 if beta1 is a tensor to ensure all
        # tensors are on the same device
        torch._foreach_lerp_(
            device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
        )
        # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2, done as a mul_ followed by addcmul_.
        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        # Due to the strictness of the _foreach_addcmul API, we can't have a single
        # tensor scalar as the scalar arg (only python number is supported there)
        # as a result, separate out the value mul
        # Filed https://github.com/pytorch/pytorch/issues/139795
        if isinstance(beta2, torch.Tensor):
            scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2)  # type: ignore[assignment]
            value = 1.0
        else:
            scaled_device_grads = device_grads  # type: ignore[assignment]
            value = 1 - beta2
        torch._foreach_addcmul_(
            device_exp_avg_sqs, scaled_device_grads, device_grads, value
        )
        # Delete the local intermediate(s) since they won't be used anymore to save on peak memory
        del device_grads
        del scaled_device_grads
        bias_correction1: tuple[Tensor, ...] | list[Tensor]
        bias_correction2: tuple[Tensor, ...] | list[Tensor]
        bias_correction2_sqrt: tuple[Tensor, ...] | list[Tensor]
        if capturable:
            # All bias-correction math is done with foreach ops on the step tensors.
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)  # type: ignore[arg-type]
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)  # type: ignore[arg-type]
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)
            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)
            torch._foreach_sqrt_(bias_correction2)
            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2
            if amsgrad:
                device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]
                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)
            # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            # Non-capturable: per-param step counts are read out with _get_value,
            # and the negated step sizes are passed as the addcdiv scalars.
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]
            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
            bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]
            if amsgrad:
                device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params,
                device_exp_avgs,
                exp_avg_sq_sqrt,
                step_size,  # type: ignore[arg-type]
            )
def _fused_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    grad_scale: Tensor | None,
    found_inf: Tensor | None,
    *,
    amsgrad: bool,
    has_complex: bool,  # Needed for consistency.
    beta1: float | Tensor,
    beta2: float | Tensor,
    lr: float | Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
    decoupled_weight_decay: bool,
) -> None:
    """Fused-kernel implementation of Adam / AdamW.

    Inputs are bucketed by (device, dtype). Each bucket is handed to
    ``torch._fused_adam_``, or to ``torch._fused_adamw_`` when
    ``decoupled_weight_decay`` is set, which updates the tensors in place.

    ``state_steps`` are incremented before the kernel runs. When ``found_inf``
    is given, its per-device copy is subtracted from the steps afterwards.

    A non-CPU Tensor ``lr``, as well as ``grad_scale`` and ``found_inf``, are
    copied at most once per device and reused for every bucket on that device.

    Raises:
        RuntimeError: if ``differentiable`` is True.
    """
    if not params:
        return
    if differentiable:
        raise RuntimeError("Adam with fused=True does not support differentiable=True")
    beta1 = _to_scalar(beta1)
    beta2 = _to_scalar(beta2)
    # Per-device caches, seeded with the caller's tensors on their own device.
    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )
    # We only shuffle around the lr when it is a Tensor and not on CPU, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: DeviceDict | None = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(list[Tensor], device_params_)
        device_grads = cast(list[Tensor], device_grads_)
        device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(list[Tensor], device_state_steps_)

        # Copy only on a cache miss. ``dict.setdefault(device, x.to(device))``
        # would evaluate the (possibly cross-device) ``.to()`` for every bucket.
        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
            device_grad_scale = grad_scale_dict[device]
        if found_inf is not None:
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)
            device_found_inf = found_inf_dict[device]

        # Resolve lr for *this* bucket without rebinding ``lr``. Rebinding it to
        # a device copy would leak that copy into later buckets whose device is
        # already cached (e.g. a second dtype on the original lr device).
        if lr_dict is not None:
            if device not in lr_dict:
                lr_dict[device] = lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            device_lr = lr_dict[device]
        else:
            device_lr = lr

        torch._foreach_add_(device_state_steps, 1)
        func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_
        # pyrefly: ignore [no-matching-overload]
        func(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=device_lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        # Undo the step increment for buckets whose step was skipped (found_inf).
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: bool | None = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: bool | None = None,
    grad_scale: Tensor | None = None,
    found_inf: Tensor | None = None,
    has_complex: bool = False,
    decoupled_weight_decay: bool = False,
    *,
    amsgrad: bool,
    beta1: float | Tensor,
    beta2: float | Tensor,
    lr: float | Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
) -> None:
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    Dispatch order: ``_fused_adam`` if ``fused``, else ``_multi_tensor_adam``
    if ``foreach``, else ``_single_tensor_adam``. When neither flag is given,
    ``foreach`` is chosen by ``_default_to_fused_or_foreach`` and ``fused`` is
    never auto-enabled. Auto-selected foreach is turned back off when a Tensor
    ``lr``/``beta1``/``beta2`` is used without ``capturable``, since the foreach
    implementation rejects that combination.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor (checked only
            outside torch.compile), or if ``foreach``/``fused`` is requested
            under ``torch.jit.script``.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported cases where lr or a beta is a
        # Tensor and capturable=False: _multi_tensor_adam rejects all of them.
        if (
            foreach
            and not capturable
            and (
                isinstance(lr, Tensor)
                or isinstance(beta1, Tensor)
                or isinstance(beta2, Tensor)
            )
        ):
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )
    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")
    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam
    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        decoupled_weight_decay=decoupled_weight_decay,
    )
|