| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360 |
- # mypy: allow-untyped-defs
- # mypy: disable-error-code=arg-type
- """Implementation of the Muon optimizer."""
- import math
- from collections.abc import MutableMapping
- import torch
- from torch import Tensor
- from .optimizer import (
- _disable_dynamo_if_unsupported,
- _params_doc,
- _to_scalar,
- Optimizer,
- ParamsT,
- )
- __all__ = ["Muon"]
# Default hyperparameters, taken from Keller Jordan's Muon write-up:
#   https://kellerjordan.github.io/posts/muon/
# GitHub permalink to the reference values:
#   https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L16
EPS: float = 1e-7  # lower bound on the norm used to normalize the NS input
DEFAULT_A: float = 3.4445  # Newton-Schulz polynomial coefficient a
DEFAULT_B: float = -4.7750  # Newton-Schulz polynomial coefficient b
DEFAULT_C: float = 2.0315  # Newton-Schulz polynomial coefficient c
DEFAULT_NS_STEPS: int = 5  # number of Newton-Schulz iterations
def _zeropower_via_newtonschulz(
    grad: Tensor, ns_coefficients: tuple[float, float, float], ns_steps: int, eps: float
) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Implementation reference: https://github.com/KellerJordan/Muon/blob/master/muon.py
    with suggestions by @jxbz, @leloykun, and @YouJiacheng.

    Args:
        grad: 2D matrix to orthogonalize. It is never modified, even when it is
            already a bfloat16 tensor.
        ns_coefficients: the three polynomial coefficients ``(a, b, c)``.
        ns_steps: number of iterations to run; must be smaller than 100.
        eps: lower bound applied to the norm used for the initial normalization.

    Returns:
        A bfloat16 tensor with the same shape as ``grad``.

    Raises:
        ValueError: if ``ns_steps >= 100``, ``grad`` is not 2D, or
            ``ns_coefficients`` does not contain exactly three values.
    """
    if ns_steps >= 100:
        raise ValueError(
            "Number of steps must be less than 100 for computational efficiency"
        )
    if grad.ndim != 2:
        raise ValueError("Input tensor gradient must be a 2D matrix")
    if len(ns_coefficients) != 3:
        raise ValueError("Coefficients must be a tuple of exactly 3 values")
    a, b, c = ns_coefficients

    # All of the iteration arithmetic below runs in bfloat16.
    ortho_grad = grad.bfloat16()

    # Work on the wide orientation so the Gram matrix X @ X.T has the smaller
    # of the two dimensions on each side.
    is_tall = grad.size(0) > grad.size(1)
    if is_tall:
        ortho_grad = ortho_grad.T

    # Ensure spectral norm is at most 1.
    # This must be out-of-place: ``grad.bfloat16()`` returns ``grad`` itself when
    # it is already bfloat16 and ``.T`` is only a view, so an in-place ``div_``
    # would overwrite the caller's tensor (e.g. the optimizer's momentum buffer).
    ortho_grad = ortho_grad / ortho_grad.norm().clamp(min=eps)

    # Perform the NS iterations: X <- a * X + (b * G + c * G @ G) @ X with G = X @ X.T
    for _ in range(ns_steps):
        gram_matrix = ortho_grad @ ortho_grad.T
        gram_update = torch.addmm(
            gram_matrix, gram_matrix, gram_matrix, beta=b, alpha=c
        )
        ortho_grad = torch.addmm(ortho_grad, gram_update, ortho_grad, beta=a)

    # Restore the caller's orientation.
    if is_tall:
        ortho_grad = ortho_grad.T
    return ortho_grad
- def _adjust_lr(lr: float, adjust_lr_fn: str | None, param_shape: torch.Size) -> float:
- """Default learning rate adjustment used by Muon."""
- A, B = param_shape[:2]
- if adjust_lr_fn is None or adjust_lr_fn == "original":
- # pyrefly: ignore [no-matching-overload]
- adjusted_ratio = math.sqrt(max(1, A / B))
- elif adjust_lr_fn == "match_rms_adamw":
- adjusted_ratio = 0.2 * math.sqrt(max(A, B))
- else:
- adjusted_ratio = 1.0
- return lr * adjusted_ratio
class Muon(Optimizer):
    # The user-facing class docstring is assigned to ``Muon.__doc__`` right after
    # the class body, where it can interpolate the module-level default constants.

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        weight_decay: float = 0.1,
        momentum: float = 0.95,
        nesterov: bool = True,
        ns_coefficients: tuple[float, float, float] = (DEFAULT_A, DEFAULT_B, DEFAULT_C),
        eps: float = EPS,
        ns_steps: int = DEFAULT_NS_STEPS,
        adjust_lr_fn: str | None = None,
    ) -> None:
        # A tensor learning rate is accepted only when it holds a single element.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")

        non_negative = (
            ("Learning rate", lr),
            ("momentum", momentum),
            ("weight decay", weight_decay),
        )
        for label, value in non_negative:
            # Spelled ``not 0.0 <= value`` so that NaN is rejected as well.
            if not 0.0 <= value:
                raise ValueError(f"{label} should be >= 0 but is: {value}")

        supported_adjustments = ("original", "match_rms_adamw")
        if adjust_lr_fn is not None and adjust_lr_fn not in supported_adjustments:
            raise ValueError(
                f"Adjust learning rate function {adjust_lr_fn} is not supported"
            )

        super().__init__(
            params,
            dict(
                lr=lr,
                weight_decay=weight_decay,
                momentum=momentum,
                nesterov=nesterov,
                ns_coefficients=ns_coefficients,
                eps=eps,
                ns_steps=ns_steps,
                adjust_lr_fn=adjust_lr_fn,
            ),
        )

        # Every registered parameter has to be a matrix; report the first one
        # that is not.
        registered = (p for group in self.param_groups for p in group["params"])
        offending = next((p for p in registered if p.ndim != 2), None)
        if offending is not None:
            raise ValueError(
                f"Muon only supports 2D parameters whereas we found a parameter with size: {offending.size()}"
            )

    def _init_group(
        self,
        group: MutableMapping,
        params_with_grad: list[Tensor],
        grads: list[Tensor],
        muon_momentum_bufs: list[Tensor],
    ) -> bool:
        """Collect the tensors of ``group`` that take part in this step.

        For every parameter that has a gradient, the parameter, its gradient and
        its momentum buffer are appended to the three output lists. The buffer
        is created as zeros, shaped like the gradient, the first time a
        parameter is seen.

        Returns:
            ``False`` always: complex parameters raise instead of being flagged.

        Raises:
            RuntimeError: for a complex parameter or a sparse gradient.
        """
        for param in group["params"]:
            grad = param.grad
            if grad is None:
                continue

            if torch.is_complex(param):
                raise RuntimeError("Muon does not support complex parameters")
            if grad.is_sparse:
                raise RuntimeError("Muon does not support sparse gradients")

            params_with_grad.append(param)
            grads.append(grad)

            param_state = self.state[param]
            if "momentum_buffer" not in param_state:
                param_state["momentum_buffer"] = torch.zeros_like(
                    grad, memory_format=torch.preserve_format
                )
            muon_momentum_bufs.append(param_state["momentum_buffer"])

        return False  # has_complex

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Hyperparameters read from each group and forwarded to the functional API.
        hyper_keys = (
            "lr",
            "weight_decay",
            "momentum",
            "nesterov",
            "ns_coefficients",
            "eps",
            "ns_steps",
            "adjust_lr_fn",
        )

        for group in self.param_groups:
            hyperparams = {key: group[key] for key in hyper_keys}

            params_with_grad: list[Tensor] = []
            grads: list[Tensor] = []
            muon_momentum_bufs: list[Tensor] = []
            has_complex = self._init_group(
                group, params_with_grad, grads, muon_momentum_bufs
            )

            muon(
                params_with_grad,
                grads,
                muon_momentum_bufs,
                has_complex=has_complex,
                **hyperparams,
            )

        return loss
# The class docstring is assembled here, after the class body, so that the
# ``Args`` section can interpolate ``_params_doc`` and the module-level defaults.
# The first piece is a plain raw string because the LaTeX uses literal braces;
# only the second piece is an f-string.
Muon.__doc__ = (
    r"""Implements Muon algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)},\ \lambda \text{ (weight decay)},\
               \mu \text{ (momentum)},\ \textit{nesterov}\in\{True,False\},\\
            &\hspace{13mm}(a,b,c)\ \text{ (NS coefficients)},\
               \varepsilon \text{ (epsilon)},\ k \text{ (NS steps)},\
               \theta_0 \text{ (params)},\ f(\theta) \text{ (objective)}                         \\
            &\textbf{initialize} : B_0 \leftarrow 0 \text{ (momentum buffer)}                    \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for}\ t=1\ \textbf{to}\ \ldots\ \textbf{do}                                 \\[0.25ex]
            &\hspace{5mm} g_t \leftarrow \nabla_{\theta} f_t(\theta_{t-1})                       \\[0.25ex]
            &\hspace{5mm} B_t \leftarrow \mu B_{t-1} + (1-\mu)\, g_t                             \\[0.25ex]
            &\hspace{5mm} \widetilde{B}_t \leftarrow
                \begin{cases}
                   (1-\mu)\, g_t + \mu B_t, & \text{if nesterov}=True \\
                   B_t,                     & \text{if nesterov}=False
                \end{cases}                                                                      \\[1.0ex]
            &\hspace{5mm} O_t \leftarrow \mathrm{NS}^{(a,b,c)}_{k}\!\big(\widetilde{B}_t;\ \varepsilon\big) \\[0.5ex]
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma\,\lambda\,\theta_{t-1}
               \quad\text{(decoupled weight decay)}                                              \\[0.25ex]
            &\hspace{5mm} \gamma \leftarrow \mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big) \\[0.25ex]
            &\hspace{5mm} \theta_t \leftarrow \theta_t - \gamma\, O_t                            \\
            &\rule{110mm}{0.4pt}                                                                 \\[-1.ex]
            &\mathbf{return}\ \theta_t                                                           \\[-1.ex]
            &\rule{110mm}{0.4pt}
       \end{aligned}

    Here, :math:`\mathrm{NS}^{(a,b,c)}_{k}(\cdot;\varepsilon)` denotes :math:`k` iterations of the
    Newton–Schulz orthogonalization operator parameterized by coefficients :math:`(a,b,c)`
    with numerical stabilization :math:`\varepsilon`.

    The purpose for :math:`\mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big)`
    is to make the orthogonalized update have a consistent :math:`RMS` across rectangular matrices.

    Keller's original implementation scales the update by :math:`\sqrt{\max\!\left(1, \frac{A}{B}\right)}`,
    where :math:`A` and :math:`B` are the dimensions of the matrix being optimized.

    Moonshot's implementation also focuses on matching :math:`RMS` of AdamW. The adjustment is computed as:
    :math:`\gamma \leftarrow {0.2}\gamma\,\sqrt{\max\!\left({A}, {B}\right)}`.
    The method is adopted from `Muon is Scalable for LLM Training`_. Research
    results show that with this adjustment Muon can directly reuse the learning rate
    and weight decay tuned for AdamW.

    We provide two options for the learning rate adjustment: "original", which follows Keller's
    implementation, and "match_rms_adamw", which refers to Moonshot's implementation. This gives users the
    flexibility to choose between the two. If `adjust_lr_fn` is not specified, the default is "original".

    For further details regarding the algorithm we refer to `Muon: An optimizer for hidden layers in neural networks`_
    and `Muon is Scalable for LLM Training`_.
    """
    + rf"""
    Args:
        {_params_doc}. Note that Muon is an optimizer for 2D parameters of neural network hidden layers. Other
            parameters, such as bias, and embedding, should be optimized by a standard method such as AdamW.
        lr (float, Tensor, optional): learning rate (default: 1e-3).
        weight_decay (float, optional): decoupled weight decay coefficient. (default: 0.1)
        momentum (float, optional): momentum factor (default: 0.95)
        nesterov (bool, optional): enables Nesterov momentum. Only applicable
            when momentum is non-zero. (default: True)
        ns_coefficients (tuple of three floats, optional): coefficients :math:`(a,b,c)` for the
            Newton–Schulz orthogonalization polynomial (default: ({DEFAULT_A}, {DEFAULT_B}, {DEFAULT_C}))
        eps (float, optional): term added to the denominator for numerical stability. (default: {EPS})
        ns_steps (int, optional): number of Newton–Schulz iteration steps. (default: {DEFAULT_NS_STEPS})
        adjust_lr_fn (str, optional): function to adjust learning rate. One of "original" and "match_rms_adamw".
            If not specified, we will default to use "original". (default: None)

    .. _Muon\: An optimizer for hidden layers in neural networks:
        https://kellerjordan.github.io/posts/muon/
    .. _Muon is Scalable for LLM Training:
        https://arxiv.org/pdf/2502.16982

    """
)
def _single_tensor_muon(
    params: list[Tensor],
    grads: list[Tensor],
    muon_momentum_bufs: list[Tensor],
    *,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_coefficients: tuple[float, float, float],
    ns_steps: int,
    eps: float,
    adjust_lr_fn: str | None,
    has_complex: bool,
) -> None:
    """Apply one Muon update to each parameter, one tensor at a time.

    ``params``, ``grads`` and ``muon_momentum_bufs`` are parallel lists; the
    parameters and the momentum buffers are updated in place.

    Raises:
        ValueError: if ``has_complex`` is set or a gradient is not 2D.
    """
    lr = _to_scalar(lr)
    if has_complex:
        raise ValueError("Complex parameters are not supported")

    for idx, weight in enumerate(params):
        g = grads[idx]
        if g.ndim != 2:
            raise ValueError("Param gradient must be a 2D matrix")

        # Momentum buffer update: buf <- momentum * buf + (1 - momentum) * g
        momentum_buf = muon_momentum_bufs[idx]
        momentum_buf.lerp_(g, 1 - momentum)

        # Nesterov blends the fresh gradient with the updated buffer; otherwise
        # the buffer itself is the pre-orthogonalization update.
        if nesterov:
            raw_update = g.lerp(momentum_buf, momentum)
        else:
            raw_update = momentum_buf

        ortho_update = _zeropower_via_newtonschulz(
            raw_update, ns_coefficients, ns_steps, eps
        )
        step_size = _adjust_lr(lr, adjust_lr_fn, weight.shape)

        # Weight decay uses the unadjusted learning rate; the orthogonalized
        # update uses the shape-adjusted one.
        weight.mul_(1 - lr * weight_decay)
        weight.add_(ortho_update, alpha=-step_size)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_muon)
def muon(
    params: list[Tensor],
    grads: list[Tensor],
    muon_momentum_bufs: list[Tensor],
    *,
    foreach: bool | None = None,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_coefficients: tuple[float, float, float],
    ns_steps: int,
    eps: float,
    adjust_lr_fn: str | None,
    has_complex: bool,
) -> None:
    r"""Functional API that performs Muon algorithm computation.

    See :class:`~torch.optim.Muon` for details.

    Raises:
        RuntimeError: if ``foreach`` is truthy; only the single-tensor
            implementation is available.
    """
    # ``None`` and ``False`` both select the single-tensor implementation.
    if foreach:
        raise RuntimeError("Foreach is not supported for Muon yet")

    _single_tensor_muon(
        params,
        grads,
        muon_momentum_bufs,
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=nesterov,
        ns_coefficients=ns_coefficients,
        ns_steps=ns_steps,
        eps=eps,
        adjust_lr_fn=adjust_lr_fn,
        has_complex=has_complex,
    )
|