muon.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011
  1. """ Muon Optimizer
  2. Improved Muon optimizer implementation with flexible handling of high-dimensional tensors.
  3. Combines PyTorch-style structure with options for:
  4. - Batched spatial processing for convolutions in addition to flatten
  5. - Optional spatial normalization
  6. - Selectable coefficient presets
  7. - Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups
  8. - AdaMuon (https://arxiv.org/abs/2507.11005)
  9. - mUP eps damping factor (https://arxiv.org/abs/2512.05620v1)
  10. TODO look into mUP LR scaling and independent weight-decay scale
  11. Based on implementation by Keller Jordan, see
  12. - https://github.com/KellerJordan/Muon/blob/master/muon.py
  13. - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
  14. - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py
  15. - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
  16. Hacked together by Ross Wightman
  17. """
  18. import logging
  19. import numbers
  20. from typing import List, Mapping, Optional, Sequence, Tuple, Union
  21. import torch
  22. try:
  23. from torch.distributed.tensor import DTensor
  24. has_dtensor = True
  25. except ImportError:
  26. has_dtensor = False
  27. from ._types import ParamsT
  28. from .adamw import adamw
  29. from .nadamw import nadamw
  30. _logger = logging.getLogger(__name__)
  31. # Constants from Keller Jordan's Muon
  32. MUON_EPS = 1e-7
  33. DEFAULT_NS_STEPS = 5
  34. _COEFFICIENTS = {
  35. "original": [
  36. # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/
  37. (3.4445, -4.7750, 2.0315),
  38. ],
  39. "quintic": [
  40. # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients
  41. # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44
  42. (4.0848, -6.8946, 2.9270),
  43. (3.9505, -6.3029, 2.6377),
  44. (3.7418, -5.5913, 2.3037),
  45. (2.8769, -3.1427, 1.2046),
  46. (2.8366, -3.0525, 1.2012),
  47. ],
  48. "polar_express": [
  49. # Polar Express https://arxiv.org/abs/2505.16932
  50. # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2
  51. (8.237312490495555, -23.157747414558198, 16.680568411445915),
  52. (4.082441999064835, -2.893047735332586, 0.5252849256975648),
  53. (3.9263479922546582, -2.8547468034765298, 0.5318022422894988),
  54. (3.2982187133085143, -2.424541981026706, 0.48632008358844075),
  55. (2.2970369434552573, -1.63662558125903, 0.4002628455953627),
  56. (1.8763805351440397, -1.2347896577722228, 0.35891887501668385),
  57. (1.8564423485617974, -1.2132449880935525, 0.3568003487825883),
  58. (1.8749994008682747, -1.2499988017229169, 0.3749994008546422),
  59. ],
  60. "polar_express_safer": [
  61. # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
  62. # w/ safety 2e-2
  63. (8.156554524902461, -22.48329292557795, 15.878769915207462),
  64. (4.0429299351667245, -2.808917465908704, 0.5000178451051299),
  65. (3.8916678022926563, -2.7724841532176825, 0.5060648178503389),
  66. (3.285753657755658, -2.3681294933425394, 0.46449024233003117),
  67. (2.3005307116270983, -1.6111665557258408, 0.3833374427545273),
  68. (1.8631210546382593, -1.2042160621002727, 0.3421879560523383),
  69. (1.8382572152247512, -1.1779263289537742, 0.3396513038637379),
  70. (1.8749999923301852, -1.2499999836060613, 0.374999991275876),
  71. ],
  72. }
  73. NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]]
  74. def scale_eps_for_ns(
  75. eps: float,
  76. shape: Tuple[int, ...],
  77. ) -> float:
  78. """Scale epsilon for Newton-Schulz based on matrix dimensions (μP-style).
  79. For μP compatibility, epsilon should scale as eps * sqrt(din/dout) to maintain
  80. consistent damping behavior across different model widths.
  81. Reference: https://arxiv.org/abs/2512.05620
  82. Args:
  83. eps: Base epsilon value
  84. shape: Shape of the matrix (out, in) or (batch, out, in)
  85. Returns:
  86. Scaled epsilon value
  87. """
  88. # Get din, dout from shape (handle both 2D and 3D batched)
  89. # FIXME TBD paper includes depth in the damping scale, e.g: eps * (din / dout) ** 0.5 / N
  90. dout, din = (shape[-2], shape[-1])
  91. return eps * (din / dout) ** 0.5
  92. def zeropower_via_newtonschulz(
  93. G: torch.Tensor,
  94. steps: int,
  95. coefficients: List[Tuple[float, float, float]],
  96. eps: float = MUON_EPS,
  97. safety_factor: float = 1.0,
  98. dtype: torch.dtype = torch.bfloat16,
  99. scale_eps: bool = False,
  100. ) -> torch.Tensor:
  101. """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient.
  102. Supports batched operation over leading dimensions.
  103. See
  104. - https://github.com/KellerJordan/Muon/blob/master/muon.py
  105. - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
  106. - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
  107. Args:
  108. G: Input gradient tensor of shape (m, n) or (batch, m, n)
  109. steps: Number of Newton-Schulz iterations
  110. coefficients: Coefficients (a, b, c) for the iteration
  111. eps: Numerical stability epsilon for norm
  112. safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
  113. dtype: Computation dtype
  114. scale_eps: If True, scale epsilon by sqrt(din/dout) for μP compatibility
  115. Returns:
  116. Orthogonalized tensor of same shape as G
  117. """
  118. assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first."
  119. num_cs = len(coefficients)
  120. assert num_cs >= 1 and len(coefficients[0]) == 3
  121. # match coefficients with # of steps, truncate or repeat last
  122. coeff_sequence = coefficients[:steps] if steps <= num_cs else \
  123. coefficients + [coefficients[-1]] * (steps - num_cs)
  124. # Scale epsilon by sqrt(din/dout) for μP compatibility if requested
  125. if scale_eps:
  126. eps = scale_eps_for_ns(eps, G.shape)
  127. X = G.to(dtype=dtype, copy=True)
  128. # Transpose if needed (operate on dimension with fewer elements)
  129. transposed = X.size(-2) > X.size(-1)
  130. if transposed:
  131. X = X.mT
  132. # Normalize spectral norm to at most 1
  133. if scale_eps:
  134. # more of a damping factor in this case, use add instead of clamp
  135. X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).add_(eps))
  136. else:
  137. X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_(min=eps))
  138. is_dtensor = has_dtensor and isinstance(G, DTensor)
  139. if is_dtensor:
  140. # Basic, DTensor-friendly Newton-Schulz
  141. for a, b, c in coeff_sequence:
  142. A = X @ X.mT
  143. B = b * A + c * (A @ A)
  144. X = a * X + (B @ X)
  145. else:
  146. # Fast prealloc/out= path
  147. # Batched vs unbatched fused MM
  148. mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
  149. # Pre-allocate
  150. X = X.contiguous()
  151. A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
  152. B = torch.empty_like(A)
  153. C = torch.empty_like(X)
  154. # Perform Newton-Schulz iterations
  155. for a, b, c in coeff_sequence:
  156. mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT
  157. mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A
  158. mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X
  159. X, C = C, X # swap refs to avoid copy
  160. if transposed:
  161. X = X.mT
  162. return X
  163. def get_lr_scale(
  164. param_shape: torch.Size,
  165. adjust_lr_fn: str = "match_rms_adamw",
  166. ) -> float:
  167. """Adjust learning rate based on parameter shape for Muon.
  168. Args:
  169. param_shape: Shape of the parameter tensor
  170. adjust_lr_fn: Scaling function name
  171. - "original": sqrt(max(1, out/in)) - Original Muon impl
  172. - "match_rms_adamw": 0.2 * sqrt(max(out, in)) - Kimi scaling
  173. - "rms_to_rms": sqrt(out/in) - Scion/Bernstein scaling
  174. """
  175. out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)
  176. if adjust_lr_fn == "original":
  177. # Original Muon impl (https://kellerjordan.github.io/posts/muon/)
  178. return max(1, out_chs / in_chs) ** 0.5
  179. elif adjust_lr_fn == "match_rms_adamw":
  180. # Kimi (https://arxiv.org/abs/2502.16982)
  181. return 0.2 * max(out_chs, in_chs) ** 0.5
  182. elif adjust_lr_fn == "rms_to_rms":
  183. # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion)
  184. # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
  185. return (out_chs / in_chs) ** 0.5
  186. else:
  187. assert False, f'Invalid scaling function "{adjust_lr_fn}" for Muon'
  188. def get_adamuon_lr_scale(
  189. param_shape: torch.Size,
  190. adjust_lr_fn: str = "match_rms_adamw",
  191. ) -> Tuple[float, bool]:
  192. """Adjust learning rate based on parameter shape for AdaMuon.
  193. Args:
  194. param_shape: Shape of the parameter tensor
  195. adjust_lr_fn: Scaling function name
  196. Returns:
  197. Tuple of (scale_factor, use_rms_norm)
  198. """
  199. out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)
  200. if adjust_lr_fn == "match_rms_adamw":
  201. # AdaMuon paper: normalize by RMS, then scale by 0.2 * sqrt(numel)
  202. # https://arxiv.org/abs/2507.11005
  203. return 0.2 * (out_chs * in_chs) ** 0.5, True
  204. elif adjust_lr_fn == "rms_to_rms":
  205. return (out_chs / in_chs) ** 0.5, False
  206. elif adjust_lr_fn == "rsqrt_in":
  207. return in_chs ** -0.5, False
  208. else:
  209. assert False, f'Invalid scaling function "{adjust_lr_fn}" for AdaMuon'
  210. def _is_suitable_for_muon(
  211. param: torch.Tensor,
  212. min_dim_size: int = 4,
  213. max_aspect_ratio: float = 128.,
  214. return_reason: bool = False,
  215. ) -> Union[bool, Tuple[bool, str]]:
  216. """Check if a parameter is suitable for Muon optimization.
  217. Args:
  218. param: Parameter tensor
  219. min_dim_size: Minimum size for non-unit dimensions
  220. max_aspect_ratio: Maximum allowed aspect ratio
  221. return_reason: If True, return (bool, reason_string), else just bool (faster)
  222. Returns:
  223. If return_reason=False: bool indicating suitability
  224. If return_reason=True: Tuple of (is_suitable, reason_string)
  225. Examples:
  226. (64, 128) -> True (or (True, "ok") if return_reason=True)
  227. (96, 3, 4, 4) -> True - will be flattened to (96, 48)
  228. (4, 2048) -> False - extreme aspect ratio
  229. (64,) -> False - insufficient dims
  230. (1, 196, 768) -> False - leading unit dims
  231. NOTE: these rules were created to balance complexity with covering common timm model cases
  232. Please let me know if there are non-optimal cases that you run into.
  233. """
  234. s = param.shape
  235. # Must have at least 2 non-unit dimensions
  236. if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2:
  237. return (False, "insufficient_dims") if return_reason else False
  238. # Unit dimension in first two positions indicates:
  239. # - Position embeddings (1, seq, dim)
  240. # - Depthwise convs (out, 1, h, w)
  241. # - Other degenerate cases possibly not caught by first rule
  242. if s[0] == 1 or s[1] == 1:
  243. return (False, "leading_unit_dims") if return_reason else False
  244. if param.ndim >= 3:
  245. # For 3D+ tensors, check what dimensions will be AFTER flattening
  246. # since that's what gets passed to Newton-Schulz iteration
  247. # Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod)
  248. out_ch = s[0]
  249. in_ch_with_spatial = 1
  250. for d in s[1:]:
  251. in_ch_with_spatial *= d
  252. check_dims = (out_ch, in_ch_with_spatial)
  253. else:
  254. # For 2D tensors, check as-is
  255. check_dims = s
  256. # Both dims should be >= minimum size
  257. min_size = min(check_dims)
  258. if min_size < min_dim_size:
  259. if return_reason:
  260. return False, f"min_dim_too_small:{min_size}"
  261. return False
  262. # Aspect ratio shouldn't be too extreme
  263. max_size = max(check_dims)
  264. aspect_ratio = max_size / min_size
  265. if aspect_ratio > max_aspect_ratio:
  266. if return_reason:
  267. return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}"
  268. return False
  269. return (True, "ok") if return_reason else True
  270. def reshape_for_muon(
  271. tensor: torch.Tensor,
  272. mode: str = "flatten",
  273. ) -> Tuple[torch.Tensor, torch.Size]:
  274. """Reshape high-dimensional tensor for Muon processing.
  275. Args:
  276. tensor: Input tensor of shape (out, in, *spatial)
  277. mode: How to handle spatial dimensions
  278. - "flatten": Flatten spatial into output dimension (out, in*H*W)
  279. - "batched": Batch over spatial positions (spatial_prod, out, in) for per-position orthogonalization
  280. Returns:
  281. Reshaped tensor and original shape for restoration
  282. """
  283. original_shape = tensor.shape
  284. if tensor.ndim == 2:
  285. return tensor, original_shape
  286. if tensor.ndim < 2:
  287. raise ValueError(f"Tensor must have at least 2 dimensions, got {tensor.ndim}")
  288. out_ch, in_ch = tensor.shape[:2]
  289. if mode == "flatten":
  290. # Flatten: (out, in, *spatial) -> (out, in * spatial_prod)
  291. return tensor.reshape(out_ch, -1), original_shape
  292. elif mode == "batched":
  293. # Batched: (out, in, *spatial) -> (spatial_prod, out, in)
  294. # Move spatial dimension to front so zeropower_via_newtonschulz batches over it
  295. reshaped = tensor.reshape(out_ch, in_ch, -1) # (out, in, spatial_prod)
  296. reshaped = reshaped.permute(2, 0, 1) # (spatial_prod, out, in)
  297. return reshaped, original_shape
  298. else:
  299. raise ValueError(f"Unknown mode: {mode}")
  300. def muon(
  301. params: List[torch.Tensor],
  302. grads: List[torch.Tensor],
  303. momentum_bufs: List[torch.Tensor],
  304. *,
  305. lr: float,
  306. weight_decay: float,
  307. momentum: float,
  308. nesterov: bool,
  309. ns_steps: int,
  310. ns_coefficients: NSCoeff,
  311. eps: float,
  312. safety_factor: float,
  313. adjust_lr_fn: Optional[str],
  314. conv_mode: str,
  315. normalize_spatial: bool,
  316. scale_eps: bool,
  317. ) -> None:
  318. """Functional API that performs Muon algorithm computation."""
  319. _single_tensor_muon(
  320. params,
  321. grads,
  322. momentum_bufs,
  323. lr=lr,
  324. weight_decay=weight_decay,
  325. momentum=momentum,
  326. nesterov=nesterov,
  327. ns_steps=ns_steps,
  328. ns_coefficients=ns_coefficients,
  329. eps=eps,
  330. safety_factor=safety_factor,
  331. adjust_lr_fn=adjust_lr_fn,
  332. conv_mode=conv_mode,
  333. normalize_spatial=normalize_spatial,
  334. scale_eps=scale_eps,
  335. )
  336. def adamuon(
  337. params: List[torch.Tensor],
  338. grads: List[torch.Tensor],
  339. momentum_bufs: List[torch.Tensor],
  340. exp_avg_sqs: List[torch.Tensor],
  341. state_steps: List[torch.Tensor],
  342. *,
  343. lr: float,
  344. weight_decay: float,
  345. momentum: float,
  346. nesterov: bool,
  347. beta2: float,
  348. ns_steps: int,
  349. ns_coefficients: NSCoeff,
  350. eps: float,
  351. safety_factor: float,
  352. adjust_lr_fn: Optional[str],
  353. conv_mode: str,
  354. normalize_spatial: bool,
  355. scale_eps: bool,
  356. ) -> None:
  357. """Functional API that performs AdaMuon algorithm computation.
  358. AdaMuon extends Muon with element-wise second moment estimation applied
  359. to orthogonalized update directions, providing Adam-like adaptive scaling
  360. while preserving Muon's geometric benefits.
  361. Reference: https://arxiv.org/abs/2507.11005
  362. """
  363. _single_tensor_adamuon(
  364. params,
  365. grads,
  366. momentum_bufs,
  367. exp_avg_sqs,
  368. state_steps,
  369. lr=lr,
  370. weight_decay=weight_decay,
  371. momentum=momentum,
  372. nesterov=nesterov,
  373. beta2=beta2,
  374. ns_steps=ns_steps,
  375. ns_coefficients=ns_coefficients,
  376. eps=eps,
  377. safety_factor=safety_factor,
  378. adjust_lr_fn=adjust_lr_fn,
  379. conv_mode=conv_mode,
  380. normalize_spatial=normalize_spatial,
  381. scale_eps=scale_eps,
  382. )
  383. def _single_tensor_muon(
  384. params: List[torch.Tensor],
  385. grads: List[torch.Tensor],
  386. momentum_bufs: List[torch.Tensor],
  387. *,
  388. lr: float,
  389. weight_decay: float,
  390. momentum: float,
  391. nesterov: bool,
  392. ns_steps: int,
  393. ns_coefficients: NSCoeff,
  394. eps: float,
  395. safety_factor: float,
  396. adjust_lr_fn: Optional[str],
  397. conv_mode: str,
  398. normalize_spatial: bool,
  399. scale_eps: bool,
  400. ) -> None:
  401. """Single tensor Muon update."""
  402. ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)
  403. for i, param in enumerate(params):
  404. grad = grads[i]
  405. momentum_buf = momentum_bufs[i]
  406. # Apply weight decay
  407. param.mul_(1 - lr * weight_decay)
  408. # Update momentum buffer
  409. momentum_buf.lerp_(grad, 1. - momentum)
  410. update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone()
  411. # Reshape for processing (handle 3D+ tensors like conv weights)
  412. if update.ndim >= 3:
  413. update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
  414. else:
  415. update_reshaped = update
  416. original_shape = update.shape
  417. # Apply Newton-Schulz orthogonalization
  418. update_ortho = zeropower_via_newtonschulz(
  419. update_reshaped,
  420. ns_steps,
  421. ns_coefficients,
  422. eps=eps,
  423. safety_factor=safety_factor,
  424. scale_eps=scale_eps,
  425. )
  426. # Adjust learning rate based on parameter shape
  427. if adjust_lr_fn:
  428. scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)
  429. else:
  430. scale = 1.0
  431. # Apply spatial normalization and permute back if in batched mode
  432. if conv_mode == "batched" and update_ortho.ndim >= 3:
  433. if normalize_spatial:
  434. scale *= update_ortho.shape[0] ** -0.5
  435. # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
  436. update_ortho = update_ortho.permute(1, 2, 0)
  437. # Reshape back to original shape
  438. update_ortho = update_ortho.reshape(original_shape)
  439. # Apply update
  440. param.add_(update_ortho, alpha=-lr * scale)
  441. def _single_tensor_adamuon(
  442. params: List[torch.Tensor],
  443. grads: List[torch.Tensor],
  444. momentum_bufs: List[torch.Tensor],
  445. exp_avg_sqs: List[torch.Tensor],
  446. state_steps: List[torch.Tensor],
  447. *,
  448. lr: float,
  449. weight_decay: float,
  450. momentum: float,
  451. nesterov: bool,
  452. beta2: float,
  453. ns_steps: int,
  454. ns_coefficients: NSCoeff,
  455. eps: float,
  456. safety_factor: float,
  457. adjust_lr_fn: Optional[str],
  458. conv_mode: str,
  459. normalize_spatial: bool,
  460. scale_eps: bool,
  461. ) -> None:
  462. """Single tensor AdaMuon update.
  463. AdaMuon applies second-moment estimation to the orthogonalized directions,
  464. then rescales using RMS-alignment to maintain stable step sizes.
  465. Algorithm:
  466. 1. Update momentum buffer: M = β₁·M + (1-β₁)·G
  467. 2. Orthogonalize: O = Newton-Schulz(M) or Newton-Schulz(nesterov_update)
  468. 3. Update second moment: v = β₂·v + (1-β₂)·O²
  469. 4. Bias correct: v̂ = v/(1-β₂^t)
  470. 5. Adaptive scaling: Ô = O / (√v̂ + ε)
  471. 6. RMS-aligned rescaling and apply update
  472. """
  473. ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)
  474. for i, param in enumerate(params):
  475. grad = grads[i]
  476. momentum_buf = momentum_bufs[i]
  477. exp_avg_sq = exp_avg_sqs[i]
  478. step_t = state_steps[i]
  479. # Increment step
  480. step_t += 1
  481. step = step_t.item()
  482. # Apply weight decay (decoupled)
  483. param.mul_(1 - lr * weight_decay)
  484. # Update momentum buffer
  485. momentum_buf.lerp_(grad, 1. - momentum)
  486. update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone()
  487. # Reshape for processing (handle 3D+ tensors like conv weights)
  488. if update.ndim >= 3:
  489. update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
  490. else:
  491. update_reshaped = update
  492. original_shape = update.shape
  493. # Apply Newton-Schulz orthogonalization
  494. update_ortho = zeropower_via_newtonschulz(
  495. update_reshaped,
  496. ns_steps,
  497. ns_coefficients,
  498. eps=eps,
  499. safety_factor=safety_factor,
  500. scale_eps=scale_eps,
  501. )
  502. # Reshape back to original shape for second moment tracking
  503. if conv_mode == "batched" and update_ortho.ndim >= 3:
  504. # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
  505. update_ortho = update_ortho.permute(1, 2, 0)
  506. update_ortho = update_ortho.reshape(original_shape)
  507. # Update second moment on orthogonalized directions (element-wise)
  508. exp_avg_sq.mul_(beta2).addcmul_(update_ortho, update_ortho, value=1.0 - beta2)
  509. # Get shape-based LR scaling and whether to apply RMS normalization
  510. if adjust_lr_fn:
  511. scale, use_rms_norm = get_adamuon_lr_scale(update_ortho.shape, adjust_lr_fn)
  512. else:
  513. scale, use_rms_norm = 1.0, False
  514. if use_rms_norm:
  515. # Bias correction not needed if scaling by norm
  516. denom = exp_avg_sq.sqrt().add_(eps)
  517. else:
  518. # Bias correction for second moment
  519. bias_correction2 = 1.0 - beta2 ** step
  520. denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
  521. # Adaptive scaling: divide by sqrt of bias-corrected second moment
  522. # This is the key AdaMuon modification
  523. update_adaptive = update_ortho / denom
  524. # RMS-aligned rescaling: normalize by update norm, then scale by shape factor
  525. # Used by AdaMuon paper approach (match_rms_adamw), not by μP approach (rms_to_rms)
  526. if use_rms_norm:
  527. # eq(8) in AdaMuon paper, 0.2 / RMS(update) = 0.2 * sqrt(ndim) / frob(update)
  528. update_norm = update_adaptive.norm().add_(eps)
  529. update_adaptive = update_adaptive / update_norm
  530. # Apply spatial normalization if in batched mode
  531. if conv_mode == "batched" and len(original_shape) >= 3:
  532. if normalize_spatial:
  533. spatial_prod = 1
  534. for d in original_shape[2:]:
  535. spatial_prod *= d
  536. scale *= spatial_prod ** -0.5
  537. # Apply update
  538. param.add_(update_adaptive, alpha=-lr * scale)
  539. class Muon(torch.optim.Optimizer):
  540. """Muon - MomentUm Orthogonalized by Newton-schulz
  541. Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
  542. parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
  543. Supports two algorithms:
  544. - "muon": Standard Muon algorithm with momentum + orthogonalization
  545. - "adamuon": AdaMuon algorithm that adds element-wise second moment estimation
  546. to orthogonalized directions for Adam-like adaptive scaling
  547. """
  548. def __init__(
  549. self,
  550. params: ParamsT,
  551. lr: float = 0.02,
  552. weight_decay: float = 0,
  553. momentum: float = 0.95,
  554. nesterov: bool = False,
  555. ns_steps: int = DEFAULT_NS_STEPS,
  556. ns_coefficients: NSCoeff = "quintic",
  557. eps: float = MUON_EPS,
  558. safety_factor: float = 1.0,
  559. adjust_lr_fn: Optional[str] = "match_rms_adamw",
  560. conv_mode: str = "flatten",
  561. normalize_spatial: bool = True,
  562. adamw_lr: Optional[float] = None,
  563. betas: Tuple[float, float] = (0.9, 0.95),
  564. algo: str = "muon",
  565. scale_eps: bool = False,
  566. verbose: bool = False,
  567. ):
  568. """ Create Muon optimizer.
  569. Args:
  570. params: Iterable of parameters or dicts defining parameter groups
  571. lr: Learning rate (default: 0.02 for Muon parameters)
  572. weight_decay: Weight decay coefficient
  573. momentum: Momentum factor for Muon
  574. nesterov: Whether to use Nesterov momentum
  575. ns_steps: Number of Newton-Schulz iterations
  576. ns_coefficients: Coefficients for NS iteration
  577. eps: Numerical stability epsilon
  578. safety_factor: Multiplicative safety factor for NS norm
  579. adjust_lr_fn: LR adjustment function - "original", "match_rms_adamw", or "rms_to_rms".
  580. For adamuon mode, can set to None to disable (RMS rescaling handles scaling).
  581. conv_mode: How to handle convolutions - "flatten" or "batched"
  582. normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
  583. adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified
  584. betas: Beta coefficients - (beta1, beta2) where beta1 is used for AdamW fallback
  585. and beta2 is used for both AdamW fallback and AdaMuon second moment
  586. algo: Algorithm - "muon" for standard Muon, "adamuon" for AdaMuon with
  587. adaptive second moment estimation (https://arxiv.org/abs/2507.11005)
  588. scale_eps: If True, scale epsilon by sqrt(din/dout) in Newton-Schulz for μP
  589. compatibility (https://arxiv.org/abs/2512.05620)
  590. verbose: Log parameter routing decisions (Muon vs AdamW)
  591. Example:
  592. ```python
  593. # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D
  594. optimizer = Muon(model.parameters(), lr=0.02)
  595. # Use AdaMuon algorithm for adaptive scaling
  596. optimizer = Muon(model.parameters(), lr=6e-4, algo="adamuon")
  597. # Manual control over parameter groups
  598. optimizer = Muon([
  599. {'params': weight_matrices, 'lr': 0.02},
  600. {'params': biases, 'use_fallback': True, 'lr': 3e-4}, # use AdamW if use_fallback=True
  601. ])
  602. ```
  603. """
  604. if not 0.0 <= lr:
  605. raise ValueError(f"Invalid learning rate: {lr}")
  606. if not 0.0 <= weight_decay:
  607. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  608. if not 0.0 <= momentum < 1.0:
  609. raise ValueError(f"Invalid momentum value: {momentum}")
  610. if not 0.0 <= eps:
  611. raise ValueError(f"Invalid epsilon value: {eps}")
  612. if conv_mode not in ["flatten", "batched"]:
  613. raise ValueError(f"Invalid conv_mode: {conv_mode}")
  614. if algo not in ["muon", "adamuon"]:
  615. raise ValueError(f"Invalid algo: {algo}. Must be 'muon' or 'adamuon'")
  616. defaults = dict(
  617. lr=lr,
  618. weight_decay=weight_decay,
  619. momentum=momentum,
  620. nesterov=nesterov,
  621. ns_steps=ns_steps,
  622. ns_coefficients=ns_coefficients,
  623. eps=eps,
  624. safety_factor=safety_factor,
  625. adjust_lr_fn=adjust_lr_fn,
  626. conv_mode=conv_mode,
  627. normalize_spatial=normalize_spatial,
  628. adamw_lr=adamw_lr if adamw_lr is not None else lr,
  629. betas=betas,
  630. algo=algo,
  631. scale_eps=scale_eps,
  632. verbose=verbose,
  633. )
  634. super().__init__(params, defaults)
  635. def __setstate__(self, state):
  636. super().__setstate__(state)
  637. for group in self.param_groups:
  638. group.setdefault('algo', 'muon')
  639. group.setdefault('scale_eps', False)
  640. @torch.no_grad()
  641. def step(self, closure=None):
  642. """Performs a single optimization step."""
  643. loss = None
  644. if closure is not None:
  645. with torch.enable_grad():
  646. loss = closure()
  647. verbose = self.defaults.get("verbose", False)
  648. # Tracking for logging (populated on first encounter of each param)
  649. muon_count = 0
  650. adamw_count = 0
  651. routing_reasons = {} if verbose else None
  652. for group in self.param_groups:
  653. algo = group.get("algo", "muon")
  654. # Separate params into Muon and AdamW groups
  655. muon_params = []
  656. muon_grads = []
  657. muon_momentum_bufs = []
  658. # Additional state for adamuon mode
  659. muon_exp_avg_sqs = []
  660. muon_state_steps = []
  661. adamw_params = []
  662. adamw_grads = []
  663. adamw_exp_avgs = []
  664. adamw_exp_avg_sqs = []
  665. adamw_state_steps = []
  666. for p in group["params"]:
  667. if p.grad is None:
  668. continue
  669. if p.grad.is_sparse:
  670. raise RuntimeError("Muon does not support sparse gradients")
  671. state = self.state[p]
  672. # Determine routing on first encounter (cache in state)
  673. if "use_muon" not in state:
  674. # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
  675. reason = None
  676. if group.get("use_fallback", False):
  677. # use_fallback=True means use AdamW (use_muon=False)
  678. state["use_muon"] = False
  679. if verbose:
  680. reason = "use_fallback_flag"
  681. elif "use_muon" in group:
  682. # Explicit use_muon flag for compatibility with other Muon implementations
  683. state["use_muon"] = group["use_muon"]
  684. if verbose:
  685. reason = "use_muon_flag"
  686. else:
  687. # Check shape suitability
  688. if verbose:
  689. suitable, reason = _is_suitable_for_muon(p, return_reason=True)
  690. else:
  691. suitable = _is_suitable_for_muon(p, return_reason=False)
  692. state["use_muon"] = suitable
  693. # Track routing decision for logging
  694. if routing_reasons is not None and reason is not None:
  695. shape_str = "x".join(str(s) for s in p.shape)
  696. if shape_str not in routing_reasons:
  697. routing_reasons[shape_str] = []
  698. routing_reasons[shape_str].append(reason)
  699. # Use cached routing decision
  700. use_muon = state["use_muon"]
  701. if use_muon:
  702. # Collect Muon params
  703. muon_params.append(p)
  704. muon_grads.append(p.grad)
  705. muon_count += 1
  706. # State initialization for Muon/AdaMuon
  707. if "momentum_buffer" not in state:
  708. state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  709. muon_momentum_bufs.append(state["momentum_buffer"])
  710. # Additional state for adamuon mode
  711. if algo == "adamuon":
  712. if "step" not in state:
  713. state["step"] = torch.tensor(0.)
  714. state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  715. muon_exp_avg_sqs.append(state["exp_avg_sq"])
  716. muon_state_steps.append(state["step"])
  717. else:
  718. # Collect AdamW/NAdamW params
  719. adamw_params.append(p)
  720. adamw_grads.append(p.grad)
  721. adamw_count += 1
  722. # State initialization for AdamW
  723. if "step" not in state:
  724. state["step"] = torch.tensor(0.)
  725. state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  726. state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  727. adamw_exp_avgs.append(state["exp_avg"])
  728. adamw_exp_avg_sqs.append(state["exp_avg_sq"])
  729. adamw_state_steps.append(state["step"])
  730. # Apply Muon/AdaMuon updates
  731. if muon_params:
  732. if algo == "adamuon":
  733. _, beta2 = group["betas"]
  734. adamuon(
  735. muon_params,
  736. muon_grads,
  737. muon_momentum_bufs,
  738. muon_exp_avg_sqs,
  739. muon_state_steps,
  740. lr=group["lr"],
  741. weight_decay=group["weight_decay"],
  742. momentum=group["momentum"],
  743. nesterov=group["nesterov"],
  744. beta2=beta2,
  745. ns_steps=group["ns_steps"],
  746. ns_coefficients=group["ns_coefficients"],
  747. eps=group["eps"],
  748. safety_factor=group["safety_factor"],
  749. adjust_lr_fn=group["adjust_lr_fn"],
  750. conv_mode=group["conv_mode"],
  751. normalize_spatial=group["normalize_spatial"],
  752. scale_eps=group["scale_eps"],
  753. )
  754. else:
  755. muon(
  756. muon_params,
  757. muon_grads,
  758. muon_momentum_bufs,
  759. lr=group["lr"],
  760. weight_decay=group["weight_decay"],
  761. momentum=group["momentum"],
  762. nesterov=group["nesterov"],
  763. ns_steps=group["ns_steps"],
  764. ns_coefficients=group["ns_coefficients"],
  765. eps=group["eps"],
  766. safety_factor=group["safety_factor"],
  767. adjust_lr_fn=group["adjust_lr_fn"],
  768. conv_mode=group["conv_mode"],
  769. normalize_spatial=group["normalize_spatial"],
  770. scale_eps=group["scale_eps"],
  771. )
  772. # Apply AdamW updates
  773. if adamw_params:
  774. beta1, beta2 = group["betas"]
  775. if group["nesterov"]:
  776. # use nadamw for fallback optimizer if nesterov is enabled
  777. nadamw(
  778. adamw_params,
  779. adamw_grads,
  780. adamw_exp_avgs,
  781. adamw_exp_avg_sqs,
  782. adamw_state_steps,
  783. foreach=None,
  784. beta1=beta1,
  785. beta2=beta2,
  786. lr=group["adamw_lr"],
  787. weight_decay=group["weight_decay"],
  788. eps=group["eps"],
  789. caution=False,
  790. maximize=False,
  791. capturable=False,
  792. max_lr=None,
  793. )
  794. else:
  795. adamw(
  796. adamw_params,
  797. adamw_grads,
  798. adamw_exp_avgs,
  799. adamw_exp_avg_sqs,
  800. [], # max_exp_avg_sqs (not using amsgrad)
  801. adamw_state_steps,
  802. foreach=None,
  803. amsgrad=False,
  804. beta1=beta1,
  805. beta2=beta2,
  806. lr=group["adamw_lr"],
  807. weight_decay=group["weight_decay"],
  808. eps=group["eps"],
  809. caution=False,
  810. maximize=False,
  811. capturable=False,
  812. max_lr=None,
  813. )
  814. # Log routing summary when we have new routing decisions
  815. if routing_reasons and len(routing_reasons) > 0:
  816. # Concise summary
  817. _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW")
  818. # Group by reason for detailed breakdown
  819. reason_groups = {}
  820. for shape_str, reasons in sorted(routing_reasons.items()):
  821. for reason in reasons:
  822. if reason not in reason_groups:
  823. reason_groups[reason] = []
  824. reason_groups[reason].append(shape_str)
  825. # Log summary counts per reason
  826. reason_summary = []
  827. for reason, shapes in sorted(reason_groups.items()):
  828. reason_summary.append(f"{reason}={len(shapes)}")
  829. _logger.info(f" Breakdown: {', '.join(reason_summary)}")
  830. # Detailed breakdown at INFO level
  831. if _logger.isEnabledFor(logging.INFO):
  832. for reason, shapes in sorted(reason_groups.items()):
  833. optimizer_name = "Muon" if reason == "ok" else "AdamW"
  834. _logger.info(f" {reason} -> {optimizer_name}:")
  835. for shape in shapes[:10]:
  836. _logger.info(f" {shape}")
  837. if len(shapes) > 10:
  838. _logger.info(f" ... and {len(shapes) - 10} more")
  839. return loss
  840. def resolve_ns_coefficients(
  841. value: Union[str, Sequence[float], Sequence[Sequence[float]]],
  842. presets: Mapping[str, Sequence[Sequence[float]]]
  843. ) -> List[Tuple[float, float, float]]:
  844. # tiny helpers (kept inline for succinctness)
  845. is_seq = lambda x: isinstance(x, Sequence) and not isinstance(x, (str, bytes))
  846. is_real = lambda x: isinstance(x, numbers.Real) and not isinstance(x, bool)
  847. def as_coeff(x: Sequence[float]) -> Tuple[float, float, float]:
  848. if not is_seq(x) or len(x) != 3 or not all(is_real(v) for v in x):
  849. raise ValueError(f"Coefficient must be length-3 of real numbers, got: {x!r}")
  850. a, b, c = x # type: ignore[misc]
  851. return float(a), float(b), float(c)
  852. if isinstance(value, str):
  853. if value not in presets:
  854. valid = ", ".join(sorted(presets.keys()))
  855. raise ValueError(f"Unknown coefficients preset '{value}'. Valid options: {valid}")
  856. seq = presets[value]
  857. if not is_seq(seq) or len(seq) == 0:
  858. raise ValueError(f"Preset '{value}' is empty or invalid")
  859. return [as_coeff(item) for item in seq] # validate & cast
  860. if not is_seq(value):
  861. raise TypeError(
  862. "Coefficients must be a preset name (str), a 3-sequence (a,b,c), "
  863. "or a sequence of 3-sequences."
  864. )
  865. # Decide single triple vs list-of-triples by structure
  866. if len(value) == 3 and all(is_real(v) for v in value): # type: ignore[index]
  867. return [as_coeff(value)] # single triple -> wrap
  868. # Otherwise treat as list/tuple of triples
  869. out = []
  870. for i, item in enumerate(value): # type: ignore[assignment]
  871. if not is_seq(item):
  872. raise TypeError(f"Item {i} is not a sequence: {item!r}")
  873. out.append(as_coeff(item))
  874. if not out:
  875. raise ValueError("Coefficient list cannot be empty")
  876. return out