_optim_factory.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340
  1. """ Optimizer Factory w/ custom Weight Decay & Layer Decay support
  2. Hacked together by / Copyright 2021 Ross Wightman
  3. """
  4. import logging
  5. from dataclasses import dataclass
  6. from functools import partial
  7. from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type, Union
  8. from fnmatch import fnmatch
  9. import importlib
  10. import torch
  11. import torch.nn as nn
  12. import torch.optim
  13. from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
  14. from ._types import ParamsT, OptimType, OptimizerCallable
  15. from .adabelief import AdaBelief
  16. from .adafactor import Adafactor
  17. from .adafactor_bv import AdafactorBigVision
  18. from .adahessian import Adahessian
  19. from .adamp import AdamP
  20. from .adamw import AdamWLegacy
  21. from .adan import Adan
  22. from .adopt import Adopt
  23. from .kron import Kron
  24. from .lamb import Lamb
  25. from .laprop import LaProp
  26. from .lars import Lars
  27. from .lion import Lion
  28. from .lookahead import Lookahead
  29. from .madgrad import MADGRAD
  30. from .mars import Mars
  31. from .muon import Muon
  32. from .nadam import NAdamLegacy
  33. from .nadamw import NAdamW
  34. from .nvnovograd import NvNovoGrad
  35. from .radam import RAdamLegacy
  36. from .rmsprop_tf import RMSpropTF
  37. from .sgdp import SGDP
  38. from .sgdw import SGDW
  39. _logger = logging.getLogger(__name__)
  40. def _import_class(class_string: str) -> Type:
  41. """Dynamically import a class from a string."""
  42. try:
  43. module_name, class_name = class_string.rsplit(".", 1)
  44. module = importlib.import_module(module_name)
  45. return getattr(module, class_name)
  46. except (ImportError, AttributeError) as e:
  47. raise ImportError(f"Could not import {class_string}: {e}")
@dataclass(frozen=True)
class OptimInfo:
    """Immutable configuration for an optimizer.

    The dataclass is frozen, so fields cannot be reassigned after construction.

    Attributes:
        name: Unique identifier for the optimizer
        opt_class: The optimizer class, or a dotted import path string for it
        description: Brief description of the optimizer's characteristics and behavior
        has_eps: Whether the optimizer accepts epsilon parameter
        has_momentum: Whether the optimizer accepts momentum parameter
        has_betas: Whether the optimizer accepts a tuple of beta parameters
        num_betas: number of betas in tuple (valid IFF has_betas = True)
        second_order: Flag marking the optimizer as second-order (set for 'adahessian'
            in this file; how callers consume it is not shown here)
        defaults: Optional default parameters for the optimizer
    """
    name: str
    # Either a class object or a 'package.module.Class' string.
    opt_class: Union[str, OptimType]
    description: str = ''
    has_eps: bool = True
    has_momentum: bool = False
    has_betas: bool = False
    num_betas: int = 2
    second_order: bool = False
    # None means "no extra default kwargs" for this entry.
    defaults: Optional[Dict[str, Any]] = None
  70. class OptimizerRegistry:
  71. """Registry managing optimizer configurations and instantiation.
  72. This class provides a central registry for optimizer configurations and handles
  73. their instantiation with appropriate parameter groups and settings.
  74. """
  75. def __init__(self) -> None:
  76. self._optimizers: Dict[str, OptimInfo] = {}
  77. self._foreach_defaults: Set[str] = {'lion'}
  78. def register(self, info: OptimInfo) -> None:
  79. """Register an optimizer configuration.
  80. Args:
  81. info: The OptimInfo configuration containing name, type and description
  82. """
  83. name = info.name.lower()
  84. if name in self._optimizers:
  85. _logger.warning(f'Optimizer {name} already registered, overwriting')
  86. self._optimizers[name] = info
  87. def register_alias(self, alias: str, target: str) -> None:
  88. """Register an alias for an existing optimizer.
  89. Args:
  90. alias: The alias name
  91. target: The target optimizer name
  92. Raises:
  93. KeyError: If target optimizer doesn't exist
  94. """
  95. target = target.lower()
  96. if target not in self._optimizers:
  97. raise KeyError(f'Cannot create alias for non-existent optimizer {target}')
  98. self._optimizers[alias.lower()] = self._optimizers[target]
  99. def register_foreach_default(self, name: str) -> None:
  100. """Register an optimizer as defaulting to foreach=True."""
  101. self._foreach_defaults.add(name.lower())
  102. def list_optimizers(
  103. self,
  104. filter: Union[str, List[str]] = '',
  105. exclude_filters: Optional[List[str]] = None,
  106. with_description: bool = False
  107. ) -> List[Union[str, Tuple[str, str]]]:
  108. """List available optimizer names, optionally filtered.
  109. Args:
  110. filter: Wildcard style filter string (e.g., 'adam*')
  111. exclude_filters: Optional list of wildcard patterns to exclude
  112. with_description: If True, return tuples of (name, description)
  113. Returns:
  114. List of either optimizer names or (name, description) tuples
  115. """
  116. names = sorted(self._optimizers.keys())
  117. if filter:
  118. if isinstance(filter, str):
  119. filters = [filter]
  120. else:
  121. filters = filter
  122. filtered_names = set()
  123. for f in filters:
  124. filtered_names.update(n for n in names if fnmatch(n, f))
  125. names = sorted(filtered_names)
  126. if exclude_filters:
  127. for exclude_filter in exclude_filters:
  128. names = [n for n in names if not fnmatch(n, exclude_filter)]
  129. if with_description:
  130. return [(name, self._optimizers[name].description) for name in names]
  131. return names
  132. def get_optimizer_info(self, name: str) -> OptimInfo:
  133. """Get the OptimInfo for an optimizer.
  134. Args:
  135. name: Name of the optimizer
  136. Returns:
  137. OptimInfo configuration
  138. Raises:
  139. ValueError: If optimizer is not found
  140. """
  141. name = name.lower()
  142. if name not in self._optimizers:
  143. raise ValueError(f'Optimizer {name} not found in registry')
  144. return self._optimizers[name]
  145. def get_optimizer_class(
  146. self,
  147. name_or_info: Union[str, OptimInfo],
  148. bind_defaults: bool = True,
  149. ) -> Union[OptimType, OptimizerCallable]:
  150. """Get the optimizer class with any default arguments applied.
  151. This allows direct instantiation of optimizers with their default configs
  152. without going through the full factory.
  153. Args:
  154. name_or_info: Name of the optimizer
  155. bind_defaults: Bind default arguments to optimizer class via `partial` before returning
  156. Returns:
  157. Optimizer class or partial with defaults applied
  158. Raises:
  159. ValueError: If optimizer not found
  160. """
  161. if isinstance(name_or_info, str):
  162. opt_info = self.get_optimizer_info(name_or_info)
  163. else:
  164. assert isinstance(name_or_info, OptimInfo)
  165. opt_info = name_or_info
  166. if isinstance(opt_info.opt_class, str):
  167. # Special handling for APEX and BNB optimizers
  168. if opt_info.opt_class.startswith('apex.'):
  169. assert torch.cuda.is_available(), 'CUDA required for APEX optimizers'
  170. try:
  171. opt_class = _import_class(opt_info.opt_class)
  172. except ImportError as e:
  173. raise ImportError('APEX optimizers require apex to be installed') from e
  174. elif opt_info.opt_class.startswith('bitsandbytes.'):
  175. assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers'
  176. try:
  177. opt_class = _import_class(opt_info.opt_class)
  178. except ImportError as e:
  179. raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e
  180. else:
  181. opt_class = _import_class(opt_info.opt_class)
  182. else:
  183. opt_class = opt_info.opt_class
  184. # Return class or partial with defaults
  185. if bind_defaults and opt_info.defaults:
  186. opt_class = partial(opt_class, **opt_info.defaults)
  187. return opt_class
  188. def create_optimizer(
  189. self,
  190. model_or_params: Union[nn.Module, ParamsT],
  191. opt: str,
  192. lr: Optional[float] = None,
  193. weight_decay: float = 0.,
  194. momentum: float = 0.9,
  195. foreach: Optional[bool] = None,
  196. weight_decay_exclude_1d: bool = True,
  197. fallback_list: Collection[str] = (),
  198. fallback_no_weight_decay: bool = False,
  199. layer_decay: Optional[float] = None,
  200. layer_decay_min_scale: Optional[float] = None,
  201. layer_decay_no_opt_scale: Optional[float] = None,
  202. param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
  203. **kwargs: Any,
  204. ) -> torch.optim.Optimizer:
  205. """Create an optimizer instance.
  206. Args:
  207. model_or_params: Model or parameters to optimize
  208. opt: Name of optimizer to create
  209. lr: Learning rate
  210. weight_decay: Weight decay factor
  211. momentum: Momentum factor for applicable optimizers
  212. foreach: Enable/disable foreach operation
  213. weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
  214. fallback_list: Collection of parameter name patterns to use fallback optimizer for hybrid optimizers
  215. fallback_no_weight_decay: If True, params in no_weight_decay list will use fallback optimizer (e.g., AdamW for Muon)
  216. layer_decay: Layer-wise learning rate decay
  217. layer_scale_min_scale: Minimum layer scale factor clamp value
  218. layer_scale_no_opt_scale: Layer scale below which optimization is disabled
  219. param_group_fn: Optional custom parameter grouping function
  220. **kwargs: Additional optimizer-specific arguments
  221. Returns:
  222. Configured optimizer instance
  223. Raises:
  224. ValueError: If optimizer not found or configuration invalid
  225. """
  226. # Get parameters to optimize
  227. if isinstance(model_or_params, nn.Module):
  228. # Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied
  229. no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())()
  230. if param_group_fn:
  231. # run custom fn to generate param groups from nn.Module
  232. params = param_group_fn(model_or_params)
  233. elif layer_decay is not None:
  234. params = param_groups_layer_decay(
  235. model_or_params,
  236. weight_decay=weight_decay,
  237. layer_decay=layer_decay,
  238. no_weight_decay_list=no_weight_decay,
  239. fallback_list=fallback_list,
  240. fallback_no_weight_decay=fallback_no_weight_decay,
  241. weight_decay_exclude_1d=weight_decay_exclude_1d,
  242. min_scale=layer_decay_min_scale,
  243. no_opt_scale=layer_decay_no_opt_scale,
  244. )
  245. weight_decay = 0.
  246. elif weight_decay and weight_decay_exclude_1d:
  247. params = param_groups_weight_decay(
  248. model_or_params,
  249. weight_decay=weight_decay,
  250. no_weight_decay_list=no_weight_decay,
  251. fallback_list=fallback_list,
  252. fallback_no_weight_decay=fallback_no_weight_decay,
  253. )
  254. weight_decay = 0.
  255. else:
  256. params = model_or_params.parameters()
  257. else:
  258. # pass parameters / parameter groups through to optimizer
  259. params = model_or_params
  260. # Parse optimizer name
  261. opt_split = opt.lower().split('_')
  262. opt_name = opt_split[-1]
  263. use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False
  264. opt_info = self.get_optimizer_info(opt_name)
  265. # Build optimizer arguments
  266. opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs}
  267. # Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None.
  268. if lr is not None:
  269. opt_args['lr'] = lr
  270. # Apply optimizer-specific settings
  271. if opt_info.defaults:
  272. for k, v in opt_info.defaults.items():
  273. opt_args.setdefault(k, v)
  274. # timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat.
  275. if opt_info.has_momentum:
  276. opt_args.setdefault('momentum', momentum)
  277. # Remove commonly used kwargs that aren't always supported
  278. if not opt_info.has_eps:
  279. opt_args.pop('eps', None)
  280. if not opt_info.has_betas:
  281. opt_args.pop('betas', None)
  282. if foreach is not None:
  283. # Explicitly activate or deactivate multi-tensor foreach impl.
  284. # Not all optimizers support this, and those that do usually default to using
  285. # multi-tensor impl if foreach is left as default 'None' and can be enabled.
  286. opt_args.setdefault('foreach', foreach)
  287. # Create optimizer
  288. opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
  289. optimizer = opt_class(params, **opt_args)
  290. # Apply Lookahead if requested
  291. if use_lookahead:
  292. optimizer = Lookahead(optimizer)
  293. return optimizer
  294. def _register_sgd_variants(registry: OptimizerRegistry) -> None:
  295. """Register SGD-based optimizers"""
  296. sgd_optimizers = [
  297. OptimInfo(
  298. name='sgd',
  299. opt_class=torch.optim.SGD,
  300. description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
  301. has_eps=False,
  302. has_momentum=True,
  303. defaults={'nesterov': True}
  304. ),
  305. OptimInfo(
  306. name='momentum',
  307. opt_class=torch.optim.SGD,
  308. description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
  309. has_eps=False,
  310. has_momentum=True,
  311. defaults={'nesterov': False}
  312. ),
  313. OptimInfo(
  314. name='sgdp',
  315. opt_class=SGDP,
  316. description='SGD with built-in projection to unit norm sphere',
  317. has_momentum=True,
  318. defaults={'nesterov': True}
  319. ),
  320. OptimInfo(
  321. name='sgdw',
  322. opt_class=SGDW,
  323. description='SGD with decoupled weight decay and Nesterov momentum',
  324. has_eps=False,
  325. has_momentum=True,
  326. defaults={'nesterov': True}
  327. ),
  328. ]
  329. for opt in sgd_optimizers:
  330. registry.register(opt)
  331. def _register_adam_variants(registry: OptimizerRegistry) -> None:
  332. """Register Adam-based optimizers"""
  333. adam_optimizers = [
  334. OptimInfo(
  335. name='adam',
  336. opt_class=torch.optim.Adam,
  337. description='torch.optim.Adam, Adaptive Moment Estimation',
  338. has_betas=True
  339. ),
  340. OptimInfo(
  341. name='adamw',
  342. opt_class=torch.optim.AdamW,
  343. description='torch.optim.AdamW, Adam with decoupled weight decay',
  344. has_betas=True
  345. ),
  346. OptimInfo(
  347. name='adamwlegacy',
  348. opt_class=AdamWLegacy,
  349. description='legacy impl of AdamW that pre-dates inclusion to torch.optim',
  350. has_betas=True
  351. ),
  352. OptimInfo(
  353. name='adamp',
  354. opt_class=AdamP,
  355. description='Adam with built-in projection to unit norm sphere',
  356. has_betas=True,
  357. defaults={'wd_ratio': 0.01, 'nesterov': True}
  358. ),
  359. OptimInfo(
  360. name='nadam',
  361. opt_class=torch.optim.NAdam,
  362. description='torch.optim.NAdam, Adam with Nesterov momentum',
  363. has_betas=True
  364. ),
  365. OptimInfo(
  366. name='nadamlegacy',
  367. opt_class=NAdamLegacy,
  368. description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
  369. has_betas=True
  370. ),
  371. OptimInfo(
  372. name='nadamw',
  373. opt_class=NAdamW,
  374. description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
  375. has_betas=True
  376. ),
  377. OptimInfo(
  378. name='radam',
  379. opt_class=torch.optim.RAdam,
  380. description='torch.optim.RAdam, Rectified Adam with variance adaptation',
  381. has_betas=True
  382. ),
  383. OptimInfo(
  384. name='radamlegacy',
  385. opt_class=RAdamLegacy,
  386. description='legacy impl of RAdam that predates inclusion in torch.optim',
  387. has_betas=True
  388. ),
  389. OptimInfo(
  390. name='radamw',
  391. opt_class=torch.optim.RAdam,
  392. description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
  393. has_betas=True,
  394. defaults={'decoupled_weight_decay': True}
  395. ),
  396. OptimInfo(
  397. name='adamax',
  398. opt_class=torch.optim.Adamax,
  399. description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
  400. has_betas=True
  401. ),
  402. OptimInfo(
  403. name='adafactor',
  404. opt_class=Adafactor,
  405. description='Memory-efficient implementation of Adam with factored gradients',
  406. ),
  407. OptimInfo(
  408. name='adafactorbv',
  409. opt_class=AdafactorBigVision,
  410. description='Big Vision variant of Adafactor with factored gradients, half precision momentum',
  411. ),
  412. OptimInfo(
  413. name='adopt',
  414. opt_class=Adopt,
  415. description='Modified Adam that can converge with any β2 with the optimal rate',
  416. ),
  417. OptimInfo(
  418. name='adoptw',
  419. opt_class=Adopt,
  420. description='Modified AdamW (decoupled decay) that can converge with any β2 with the optimal rate',
  421. defaults={'decoupled': True}
  422. ),
  423. ]
  424. for opt in adam_optimizers:
  425. registry.register(opt)
  426. def _register_lamb_lars(registry: OptimizerRegistry) -> None:
  427. """Register LAMB and LARS variants"""
  428. lamb_lars_optimizers = [
  429. OptimInfo(
  430. name='lamb',
  431. opt_class=Lamb,
  432. description='Layer-wise Adaptive Moments for batch optimization',
  433. has_betas=True
  434. ),
  435. OptimInfo(
  436. name='lambc',
  437. opt_class=Lamb,
  438. description='LAMB with trust ratio clipping for stability',
  439. has_betas=True,
  440. defaults={'trust_clip': True}
  441. ),
  442. OptimInfo(
  443. name='lambw',
  444. opt_class=Lamb,
  445. description='LAMB with decoupled weight decay',
  446. has_betas=True,
  447. defaults={'decoupled_decay': True}
  448. ),
  449. OptimInfo(
  450. name='lambcw',
  451. opt_class=Lamb,
  452. description='LAMB with trust ratio clipping for stability and decoupled decay',
  453. has_betas=True,
  454. defaults={'trust_clip': True, 'decoupled_decay': True}
  455. ),
  456. OptimInfo(
  457. name='lars',
  458. opt_class=Lars,
  459. description='Layer-wise Adaptive Rate Scaling',
  460. has_momentum=True
  461. ),
  462. OptimInfo(
  463. name='larc',
  464. opt_class=Lars,
  465. description='LARS with trust ratio clipping for stability',
  466. has_momentum=True,
  467. defaults={'trust_clip': True}
  468. ),
  469. OptimInfo(
  470. name='nlars',
  471. opt_class=Lars,
  472. description='LARS with Nesterov momentum',
  473. has_momentum=True,
  474. defaults={'nesterov': True}
  475. ),
  476. OptimInfo(
  477. name='nlarc',
  478. opt_class=Lars,
  479. description='LARS with Nesterov momentum & trust ratio clipping',
  480. has_momentum=True,
  481. defaults={'nesterov': True, 'trust_clip': True}
  482. ),
  483. ]
  484. for opt in lamb_lars_optimizers:
  485. registry.register(opt)
  486. def _register_corrected_decay_optimizers(registry: OptimizerRegistry) -> None:
  487. """Register corrected weight decay optimizer variants"""
  488. corrected_optimizers = [
  489. OptimInfo(
  490. name='adamc',
  491. opt_class=AdamWLegacy,
  492. description='AdamW with corrected weight decay (lr²/max_lr scaling)',
  493. has_betas=True,
  494. defaults={'corrected_weight_decay': True}
  495. ),
  496. OptimInfo(
  497. name='nadamc',
  498. opt_class=NAdamW,
  499. description='NAdamW with corrected weight decay (lr²/max_lr scaling)',
  500. has_betas=True,
  501. defaults={'corrected_weight_decay': True}
  502. ),
  503. OptimInfo(
  504. name='sgdc',
  505. opt_class=SGDW,
  506. description='SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
  507. has_eps=False,
  508. has_momentum=True,
  509. defaults={'nesterov': True, 'corrected_weight_decay': True}
  510. ),
  511. OptimInfo(
  512. name='adoptc',
  513. opt_class=Adopt,
  514. description='Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
  515. defaults={'decoupled': True, 'corrected_weight_decay': True}
  516. ),
  517. OptimInfo(
  518. name='lambcd',
  519. opt_class=Lamb,
  520. description='LAMB with corrected decoupled weight decay (lr²/max_lr scaling)',
  521. has_betas=True,
  522. defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
  523. ),
  524. OptimInfo(
  525. name='kronc',
  526. opt_class=Kron,
  527. description='PSGD Kron with corrected decoupled weight decay (lr²/max_lr scaling)',
  528. has_momentum=True,
  529. defaults={'decoupled_decay': True, 'corrected_weight_decay': True}
  530. ),
  531. OptimInfo(
  532. name='lionc',
  533. opt_class=Lion,
  534. description='Lion with corrected weight decay (lr²/max_lr scaling)',
  535. has_eps=False,
  536. has_betas=True,
  537. defaults={'corrected_weight_decay': True}
  538. ),
  539. OptimInfo(
  540. name='lapropc',
  541. opt_class=LaProp,
  542. description='LaProp with corrected weight decay (lr²/max_lr scaling)',
  543. has_betas=True,
  544. defaults={'corrected_weight_decay': True}
  545. ),
  546. OptimInfo(
  547. name='rmsproptfc',
  548. opt_class=RMSpropTF,
  549. description='RMSprop TF-style with corrected decoupled weight decay (lr²/max_lr scaling)',
  550. has_momentum=True,
  551. defaults={'alpha': 0.9, 'decoupled_decay': True, 'corrected_weight_decay': True}
  552. ),
  553. OptimInfo(
  554. name='adafactorbvc',
  555. opt_class=AdafactorBigVision,
  556. description='Adafactor Big Vision with corrected weight decay (lr²/max_lr or lr/max_lr scaling)',
  557. defaults={'corrected_weight_decay': True}
  558. ),
  559. ]
  560. for opt in corrected_optimizers:
  561. registry.register(opt)
  562. # Cautious + corrected variants
  563. cautious_corrected = [
  564. OptimInfo(
  565. name='cadamc',
  566. opt_class=AdamWLegacy,
  567. description='Cautious AdamW with corrected weight decay (lr²/max_lr scaling)',
  568. has_betas=True,
  569. defaults={'caution': True, 'corrected_weight_decay': True}
  570. ),
  571. OptimInfo(
  572. name='cadoptc',
  573. opt_class=Adopt,
  574. description='Cautious Adopt with corrected decoupled weight decay (lr²/max_lr scaling)',
  575. defaults={'decoupled': True, 'caution': True, 'corrected_weight_decay': True}
  576. ),
  577. OptimInfo(
  578. name='cnadamc',
  579. opt_class=NAdamW,
  580. description='Cautious NAdamW with corrected weight decay (lr²/max_lr scaling)',
  581. has_betas=True,
  582. defaults={'caution': True, 'corrected_weight_decay': True}
  583. ),
  584. OptimInfo(
  585. name='csgdc',
  586. opt_class=SGDW,
  587. description='Cautious SGD with corrected decoupled weight decay (lr²/max_lr scaling)',
  588. has_eps=False,
  589. has_momentum=True,
  590. defaults={'nesterov': True, 'caution': True, 'corrected_weight_decay': True}
  591. ),
  592. OptimInfo(
  593. name='clionc',
  594. opt_class=Lion,
  595. description='Cautious Lion with corrected weight decay (lr²/max_lr scaling)',
  596. has_eps=False,
  597. has_betas=True,
  598. defaults={'caution': True, 'corrected_weight_decay': True}
  599. ),
  600. OptimInfo(
  601. name='cadafactorbvc',
  602. opt_class=AdafactorBigVision,
  603. description='Cautious Adafactor Big Vision with corrected weight decay',
  604. defaults={'caution': True, 'corrected_weight_decay': True}
  605. ),
  606. ]
  607. for opt in cautious_corrected:
  608. registry.register(opt)
  609. def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
  610. cautious_optimizers = [
  611. OptimInfo(
  612. name='cadafactor',
  613. opt_class=Adafactor,
  614. description='Cautious Adafactor',
  615. defaults={'caution': True}
  616. ),
  617. OptimInfo(
  618. name='cadafactorbv',
  619. opt_class=AdafactorBigVision,
  620. description='Cautious Big Vision Adafactor',
  621. defaults={'caution': True}
  622. ),
  623. OptimInfo(
  624. name='cadamw',
  625. opt_class=AdamWLegacy,
  626. description='Cautious AdamW',
  627. has_betas=True,
  628. defaults={'caution': True}
  629. ),
  630. OptimInfo(
  631. name='cadopt',
  632. opt_class=Adopt,
  633. description='Cautious Adopt',
  634. defaults={'caution': True}
  635. ),
  636. OptimInfo(
  637. name='cadan',
  638. opt_class=Adan,
  639. description='Cautious Adaptive Nesterov Momentum Algorithm',
  640. defaults={'caution': True, 'no_prox': False},
  641. has_betas=True,
  642. num_betas=3
  643. ),
  644. OptimInfo(
  645. name='cadanw',
  646. opt_class=Adan,
  647. description='Cautious Adaptive Nesterov Momentum with decoupled weight decay',
  648. defaults={'caution': True, 'no_prox': True},
  649. has_betas=True,
  650. num_betas=3
  651. ),
  652. OptimInfo(
  653. name='cadoptw',
  654. opt_class=Adopt,
  655. description='Cautious AdoptW (decoupled decay)',
  656. defaults={'decoupled': True, 'caution': True}
  657. ),
  658. OptimInfo(
  659. name='clamb',
  660. opt_class=Lamb,
  661. description='Cautious LAMB',
  662. has_betas=True,
  663. defaults={'caution': True}
  664. ),
  665. OptimInfo(
  666. name='clambw',
  667. opt_class=Lamb,
  668. description='Cautious LAMB with decoupled weight decay',
  669. has_betas=True,
  670. defaults={'caution': True, 'decoupled_decay': True}
  671. ),
  672. OptimInfo(
  673. name='claprop',
  674. opt_class=LaProp,
  675. description='Cautious LaProp',
  676. has_betas=True,
  677. defaults={'caution': True}
  678. ),
  679. OptimInfo(
  680. name='clion',
  681. opt_class=Lion,
  682. description='Cautious Lion',
  683. has_eps=False,
  684. has_betas=True,
  685. defaults = {'caution': True}
  686. ),
  687. OptimInfo(
  688. name='cmars',
  689. opt_class=Mars,
  690. description='Cautious MARS',
  691. has_betas=True,
  692. defaults={'caution': True}
  693. ),
  694. OptimInfo(
  695. name='cnadamw',
  696. opt_class=NAdamW,
  697. description='Cautious NAdamW',
  698. has_betas=True,
  699. defaults={'caution': True}
  700. ),
  701. OptimInfo(
  702. name='crmsproptf',
  703. opt_class=RMSpropTF,
  704. description='Cautious TensorFlow-style RMSprop',
  705. has_momentum=True,
  706. defaults={'alpha': 0.9, 'caution': True}
  707. ),
  708. OptimInfo(
  709. name='csgdw',
  710. opt_class=SGDW,
  711. description='Cautious SGD with decoupled weight decay and Nesterov momentum',
  712. has_eps=False,
  713. has_momentum=True,
  714. defaults={'nesterov': True, 'caution': True}
  715. ),
  716. OptimInfo(
  717. name='cadamp',
  718. opt_class=AdamP,
  719. description='Add the spherical cautious optimizer and the standard cautious optimizer to AdamP',
  720. has_betas=True,
  721. defaults={'wd_ratio': 0.01, 'nesterov': True, 'caution': True}
  722. ),
  723. OptimInfo(
  724. name='csgdp',
  725. opt_class=SGDP,
  726. description='Add the spherical cautious optimizer and the standard cautious optimizer to SGDP',
  727. has_momentum=True,
  728. defaults={'nesterov': True, 'caution': True}
  729. ),
  730. ]
  731. for opt in cautious_optimizers:
  732. registry.register(opt)
  733. def _register_other_optimizers(registry: OptimizerRegistry) -> None:
  734. """Register miscellaneous optimizers"""
  735. other_optimizers = [
  736. OptimInfo(
  737. name='adabelief',
  738. opt_class=AdaBelief,
  739. description='Adapts learning rate based on gradient prediction error',
  740. has_betas=True,
  741. defaults={'rectify': False}
  742. ),
  743. OptimInfo(
  744. name='radabelief',
  745. opt_class=AdaBelief,
  746. description='Rectified AdaBelief with variance adaptation',
  747. has_betas=True,
  748. defaults={'rectify': True}
  749. ),
  750. OptimInfo(
  751. name='adadelta',
  752. opt_class=torch.optim.Adadelta,
  753. description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
  754. ),
  755. OptimInfo(
  756. name='adagrad',
  757. opt_class=torch.optim.Adagrad,
  758. description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
  759. defaults={'eps': 1e-8}
  760. ),
  761. OptimInfo(
  762. name='adan',
  763. opt_class=Adan,
  764. description='Adaptive Nesterov Momentum Algorithm',
  765. defaults={'no_prox': False},
  766. has_betas=True,
  767. num_betas=3
  768. ),
  769. OptimInfo(
  770. name='adanw',
  771. opt_class=Adan,
  772. description='Adaptive Nesterov Momentum with decoupled weight decay',
  773. defaults={'no_prox': True},
  774. has_betas=True,
  775. num_betas=3
  776. ),
  777. OptimInfo(
  778. name='adahessian',
  779. opt_class=Adahessian,
  780. description='An Adaptive Second Order Optimizer',
  781. has_betas=True,
  782. second_order=True,
  783. ),
  784. OptimInfo(
  785. name='kron',
  786. opt_class=Kron,
  787. description='PSGD optimizer with Kronecker-factored preconditioner',
  788. has_eps=False,
  789. has_momentum=True,
  790. ),
  791. OptimInfo(
  792. name='kronw',
  793. opt_class=Kron,
  794. description='PSGD optimizer with Kronecker-factored preconditioner and decoupled weight decay',
  795. has_momentum=True,
  796. has_eps=False,
  797. defaults={'decoupled_decay': True}
  798. ),
  799. OptimInfo(
  800. name='laprop',
  801. opt_class=LaProp,
  802. description='Separating Momentum and Adaptivity in Adam',
  803. has_betas=True,
  804. ),
  805. OptimInfo(
  806. name='lion',
  807. opt_class=Lion,
  808. description='Evolved Sign Momentum optimizer for improved convergence',
  809. has_eps=False,
  810. has_betas=True
  811. ),
  812. OptimInfo(
  813. name='madgrad',
  814. opt_class=MADGRAD,
  815. description='Momentum-based Adaptive gradient method',
  816. has_momentum=True
  817. ),
  818. OptimInfo(
  819. name='madgradw',
  820. opt_class=MADGRAD,
  821. description='MADGRAD with decoupled weight decay',
  822. has_momentum=True,
  823. defaults={'decoupled_decay': True}
  824. ),
  825. OptimInfo(
  826. name='mars',
  827. opt_class=Mars,
  828. description='Unleashing the Power of Variance Reduction for Training Large Models',
  829. has_betas=True,
  830. ),
  831. OptimInfo(
  832. name='muon',
  833. opt_class=Muon,
  834. description='MomentUm Orthogonalized by Newton-schulz with AdamW fallback for 1D params',
  835. has_momentum=True,
  836. has_eps=True,
  837. has_betas=True,
  838. ),
  839. OptimInfo(
  840. name='nmuon',
  841. opt_class=Muon,
  842. description='MomentUm Orthogonalized by Newton-schulz with Nesterov and NAdamW fallback for 1D params',
  843. has_momentum=True,
  844. has_eps=True,
  845. has_betas=True,
  846. defaults={'nesterov': True}
  847. ),
  848. OptimInfo(
  849. name='adamuon',
  850. opt_class=Muon,
  851. description='AdaMuon: Muon with adaptive second moment estimation on orthogonalized directions',
  852. has_momentum=True,
  853. has_eps=True,
  854. has_betas=True,
  855. defaults={'algo': 'adamuon'}
  856. ),
  857. OptimInfo(
  858. name='nadamuon',
  859. opt_class=Muon,
  860. description='AdaMuon with Nesterov momentum and NAdamW fallback for 1D params',
  861. has_momentum=True,
  862. has_eps=True,
  863. has_betas=True,
  864. defaults={'algo': 'adamuon', 'nesterov': True}
  865. ),
  866. OptimInfo(
  867. name='novograd',
  868. opt_class=NvNovoGrad,
  869. description='Normalized Adam with L2 norm gradient normalization',
  870. has_betas=True
  871. ),
  872. OptimInfo(
  873. name='rmsprop',
  874. opt_class=torch.optim.RMSprop,
  875. description='torch.optim.RMSprop, Root Mean Square Propagation',
  876. has_momentum=True,
  877. defaults={'alpha': 0.9}
  878. ),
  879. OptimInfo(
  880. name='rmsproptf',
  881. opt_class=RMSpropTF,
  882. description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation',
  883. has_momentum=True,
  884. defaults={'alpha': 0.9}
  885. ),
  886. ]
  887. for opt in other_optimizers:
  888. registry.register(opt)
  889. registry.register_foreach_default('lion')
  890. def _register_apex_optimizers(registry: OptimizerRegistry) -> None:
  891. """Register APEX optimizers (lazy import)"""
  892. apex_optimizers = [
  893. OptimInfo(
  894. name='fusedsgd',
  895. opt_class='apex.optimizers.FusedSGD',
  896. description='NVIDIA APEX fused SGD implementation for faster training',
  897. has_eps=False,
  898. has_momentum=True,
  899. defaults={'nesterov': True}
  900. ),
  901. OptimInfo(
  902. name='fusedadam',
  903. opt_class='apex.optimizers.FusedAdam',
  904. description='NVIDIA APEX fused Adam implementation',
  905. has_betas=True,
  906. defaults={'adam_w_mode': False}
  907. ),
  908. OptimInfo(
  909. name='fusedadamw',
  910. opt_class='apex.optimizers.FusedAdam',
  911. description='NVIDIA APEX fused AdamW implementation',
  912. has_betas=True,
  913. defaults={'adam_w_mode': True}
  914. ),
  915. OptimInfo(
  916. name='fusedlamb',
  917. opt_class='apex.optimizers.FusedLAMB',
  918. description='NVIDIA APEX fused LAMB implementation',
  919. has_betas=True
  920. ),
  921. OptimInfo(
  922. name='fusednovograd',
  923. opt_class='apex.optimizers.FusedNovoGrad',
  924. description='NVIDIA APEX fused NovoGrad implementation',
  925. has_betas=True,
  926. defaults={'betas': (0.95, 0.98)}
  927. ),
  928. ]
  929. for opt in apex_optimizers:
  930. registry.register(opt)
  931. def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
  932. """Register bitsandbytes optimizers (lazy import)"""
  933. bnb_optimizers = [
  934. OptimInfo(
  935. name='bnbsgd',
  936. opt_class='bitsandbytes.optim.SGD',
  937. description='bitsandbytes SGD',
  938. has_eps=False,
  939. has_momentum=True,
  940. defaults={'nesterov': True}
  941. ),
  942. OptimInfo(
  943. name='bnbsgd8bit',
  944. opt_class='bitsandbytes.optim.SGD8bit',
  945. description='bitsandbytes 8-bit SGD with dynamic quantization',
  946. has_eps=False,
  947. has_momentum=True,
  948. defaults={'nesterov': True}
  949. ),
  950. OptimInfo(
  951. name='bnbadam',
  952. opt_class='bitsandbytes.optim.Adam',
  953. description='bitsandbytes Adam',
  954. has_betas=True
  955. ),
  956. OptimInfo(
  957. name='bnbadam8bit',
  958. opt_class='bitsandbytes.optim.Adam',
  959. description='bitsandbytes 8-bit Adam with dynamic quantization',
  960. has_betas=True
  961. ),
  962. OptimInfo(
  963. name='bnbadamw',
  964. opt_class='bitsandbytes.optim.AdamW',
  965. description='bitsandbytes AdamW',
  966. has_betas=True
  967. ),
  968. OptimInfo(
  969. name='bnbadamw8bit',
  970. opt_class='bitsandbytes.optim.AdamW',
  971. description='bitsandbytes 8-bit AdamW with dynamic quantization',
  972. has_betas=True
  973. ),
  974. OptimInfo(
  975. 'bnblion',
  976. 'bitsandbytes.optim.Lion',
  977. description='bitsandbytes Lion',
  978. has_eps=False,
  979. has_betas=True
  980. ),
  981. OptimInfo(
  982. 'bnblion8bit',
  983. 'bitsandbytes.optim.Lion8bit',
  984. description='bitsandbytes 8-bit Lion with dynamic quantization',
  985. has_eps=False,
  986. has_betas=True
  987. ),
  988. OptimInfo(
  989. 'bnbademamix',
  990. 'bitsandbytes.optim.AdEMAMix',
  991. description='bitsandbytes AdEMAMix',
  992. has_betas=True,
  993. num_betas=3,
  994. ),
  995. OptimInfo(
  996. 'bnbademamix8bit',
  997. 'bitsandbytes.optim.AdEMAMix8bit',
  998. description='bitsandbytes 8-bit AdEMAMix with dynamic quantization',
  999. has_betas=True,
  1000. num_betas=3,
  1001. ),
  1002. ]
  1003. for opt in bnb_optimizers:
  1004. registry.register(opt)
  1005. default_registry = OptimizerRegistry()
  1006. def _register_default_optimizers() -> None:
  1007. """Register all default optimizers to the global registry."""
  1008. # Register all optimizer groups
  1009. _register_sgd_variants(default_registry)
  1010. _register_adam_variants(default_registry)
  1011. _register_lamb_lars(default_registry)
  1012. _register_other_optimizers(default_registry)
  1013. _register_apex_optimizers(default_registry)
  1014. _register_bnb_optimizers(default_registry)
  1015. _register_cautious_optimizers(default_registry)
  1016. _register_corrected_decay_optimizers(default_registry)
  1017. # Register aliases
  1018. default_registry.register_alias('nesterov', 'sgd')
  1019. default_registry.register_alias('nesterovw', 'sgdw')
  1020. # Initialize default registry
  1021. _register_default_optimizers()
  1022. # Public API
  1023. def list_optimizers(
  1024. filter: Union[str, List[str]] = '',
  1025. exclude_filters: Optional[List[str]] = None,
  1026. with_description: bool = False,
  1027. ) -> List[Union[str, Tuple[str, str]]]:
  1028. """List available optimizer names, optionally filtered.
  1029. List all registered optimizers, with optional filtering using wildcard patterns.
  1030. Optimizers can be filtered using include and exclude patterns, and can optionally
  1031. return descriptions with each optimizer name.
  1032. Args:
  1033. filter: Wildcard style filter string or list of filter strings
  1034. (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
  1035. Adam variants and 8-bit optimizers). Empty string means no filtering.
  1036. exclude_filters: Optional list of wildcard patterns to exclude. For example,
  1037. ['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
  1038. with_description: If True, returns tuples of (name, description) instead of
  1039. just names. Descriptions provide brief explanations of optimizer characteristics.
  1040. Returns:
  1041. If with_description is False:
  1042. List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
  1043. If with_description is True:
  1044. List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
  1045. Examples:
  1046. >>> list_optimizers()
  1047. ['adam', 'adamw', 'sgd', ...]
  1048. >>> list_optimizers(['la*', 'nla*']) # List lamb & lars
  1049. ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
  1050. >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*']) # Exclude bnb & apex adam optimizers
  1051. ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
  1052. >>> list_optimizers(with_description=True) # Get descriptions
  1053. [('adabelief', 'Adapts learning rate based on gradient prediction error'),
  1054. ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
  1055. ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
  1056. ...]
  1057. """
  1058. return default_registry.list_optimizers(filter, exclude_filters, with_description)
  1059. def get_optimizer_info(name: str) -> OptimInfo:
  1060. """Get the OptimInfo for an optimizer.
  1061. Args:
  1062. name: Name of the optimizer
  1063. Returns:
  1064. OptimInfo configuration
  1065. Raises:
  1066. ValueError: If optimizer is not found
  1067. """
  1068. return default_registry.get_optimizer_info(name)
  1069. def get_optimizer_class(
  1070. name: str,
  1071. bind_defaults: bool = True,
  1072. ) -> Union[OptimType, OptimizerCallable]:
  1073. """Get optimizer class by name with option to bind default arguments.
  1074. Retrieves the optimizer class or a partial function with default arguments bound.
  1075. This allows direct instantiation of optimizers with their default configurations
  1076. without going through the full factory.
  1077. Args:
  1078. name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
  1079. bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
  1080. If False, returns the raw optimizer class.
  1081. Returns:
  1082. If bind_defaults is False:
  1083. The optimizer class (e.g., torch.optim.Adam)
  1084. If bind_defaults is True:
  1085. A partial function with default arguments bound
  1086. Raises:
  1087. ValueError: If optimizer name is not found in registry
  1088. Examples:
  1089. >>> # Get SGD with nesterov momentum default
  1090. >>> SGD = get_optimizer_class('sgd') # nesterov=True bound
  1091. >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
  1092. >>> # Get raw optimizer class
  1093. >>> SGD = get_optimizer_class('sgd')
  1094. >>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
  1095. """
  1096. return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
  1097. def create_optimizer_v2(
  1098. model_or_params: Union[nn.Module, ParamsT],
  1099. opt: str = 'sgd',
  1100. lr: Optional[float] = None,
  1101. weight_decay: float = 0.,
  1102. momentum: float = 0.9,
  1103. foreach: Optional[bool] = None,
  1104. filter_bias_and_bn: bool = True,
  1105. fallback_list: Collection[str] = (),
  1106. fallback_no_weight_decay: bool = False,
  1107. layer_decay: Optional[float] = None,
  1108. layer_decay_min_scale: float = 0.0,
  1109. layer_decay_no_opt_scale: Optional[float] = None,
  1110. param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
  1111. **kwargs: Any,
  1112. ) -> torch.optim.Optimizer:
  1113. """Create an optimizer instance via timm registry.
  1114. Creates and configures an optimizer with appropriate parameter groups and settings.
  1115. Supports automatic parameter group creation for weight decay and layer-wise learning
  1116. rates, as well as custom parameter grouping.
  1117. Args:
  1118. model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
  1119. If a model is provided, parameters will be automatically extracted and grouped
  1120. based on the other arguments.
  1121. opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
  1122. Use list_optimizers() to see available options.
  1123. lr: Learning rate. If None, will use the optimizer's default.
  1124. weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
  1125. momentum: Momentum factor for optimizers that support it. Only used if the
  1126. chosen optimizer accepts a momentum parameter.
  1127. foreach: Enable/disable foreach (multi-tensor) implementation if available.
  1128. If None, will use optimizer-specific defaults.
  1129. filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
  1130. weight decay applied. Only used when model_or_params is a model and
  1131. weight_decay > 0.
  1132. fallback_list: Collection of parameter name patterns to use fallback optimizer for
  1133. hybrid optimizers (e.g., AdamW for Muon). Supports wildcard matching.
  1134. fallback_no_weight_decay: If True, params in model's no_weight_decay() list will use
  1135. fallback optimizer for hybrid optimizers (e.g., AdamW for Muon).
  1136. layer_decay: Optional layer-wise learning rate decay factor. If provided,
  1137. learning rates will be scaled by layer_decay^(max_depth - layer_depth).
  1138. Only used when model_or_params is a model.
  1139. param_group_fn: Optional function to create custom parameter groups.
  1140. If provided, other parameter grouping options will be ignored.
  1141. **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
  1142. Returns:
  1143. Configured optimizer instance.
  1144. Examples:
  1145. >>> # Basic usage with a model
  1146. >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
  1147. >>> # SGD with momentum and weight decay
  1148. >>> optimizer = create_optimizer_v2(
  1149. ... model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
  1150. ... )
  1151. >>> # Adam with layer-wise learning rate decay
  1152. >>> optimizer = create_optimizer_v2(
  1153. ... model, 'adam', lr=1e-3, layer_decay=0.7
  1154. ... )
  1155. >>> # Custom parameter groups
  1156. >>> def group_fn(model):
  1157. ... return [
  1158. ... {'params': model.backbone.parameters(), 'lr': 1e-4},
  1159. ... {'params': model.head.parameters(), 'lr': 1e-3}
  1160. ... ]
  1161. >>> optimizer = create_optimizer_v2(
  1162. ... model, 'sgd', param_group_fn=group_fn
  1163. ... )
  1164. Note:
  1165. Parameter group handling precedence:
  1166. 1. If param_group_fn is provided, it will be used exclusively
  1167. 2. If layer_decay is provided, layer-wise groups will be created
  1168. 3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
  1169. 4. Otherwise, all parameters will be in a single group
  1170. """
  1171. return default_registry.create_optimizer(
  1172. model_or_params,
  1173. opt=opt,
  1174. lr=lr,
  1175. weight_decay=weight_decay,
  1176. momentum=momentum,
  1177. foreach=foreach,
  1178. weight_decay_exclude_1d=filter_bias_and_bn,
  1179. fallback_list=fallback_list,
  1180. fallback_no_weight_decay=fallback_no_weight_decay,
  1181. layer_decay=layer_decay,
  1182. layer_decay_min_scale=layer_decay_min_scale,
  1183. layer_decay_no_opt_scale=layer_decay_no_opt_scale,
  1184. param_group_fn=param_group_fn,
  1185. **kwargs
  1186. )
  1187. def optimizer_kwargs(cfg):
  1188. """Convert argparse-style `cfg` object to kwargs for an optimizer factory."""
  1189. kwargs = {
  1190. 'opt': cfg.opt,
  1191. 'lr': cfg.lr,
  1192. 'weight_decay': cfg.weight_decay,
  1193. 'momentum': cfg.momentum,
  1194. }
  1195. if (eps := getattr(cfg, 'opt_eps', None)) is not None:
  1196. kwargs['eps'] = eps
  1197. if (betas := getattr(cfg, 'opt_betas', None)) is not None:
  1198. kwargs['betas'] = betas
  1199. if (layer_decay := getattr(cfg, 'layer_decay', None)) is not None:
  1200. kwargs['layer_decay'] = layer_decay
  1201. if (ld_min := getattr(cfg, 'layer_decay_min_scale', None)) is not None:
  1202. kwargs['layer_decay_min_scale'] = ld_min
  1203. if (ld_no_opt := getattr(cfg, 'layer_decay_no_opt_scale', None)) is not None:
  1204. kwargs['layer_decay_no_opt_scale'] = ld_no_opt
  1205. if (opt_args := getattr(cfg, 'opt_args', None)) is not None:
  1206. kwargs.update(opt_args)
  1207. if (foreach := getattr(cfg, 'opt_foreach', None)) is not None:
  1208. kwargs['foreach'] = foreach
  1209. return kwargs
  1210. def create_optimizer(
  1211. args,
  1212. model: Union[nn.Module, ParamsT],
  1213. filter_bias_and_bn: bool = True,
  1214. ) -> torch.optim.Optimizer:
  1215. """ Legacy optimizer factory for backwards compatibility.
  1216. NOTE: Use create_optimizer_v2 for new code.
  1217. """
  1218. return create_optimizer_v2(
  1219. model,
  1220. **optimizer_kwargs(cfg=args),
  1221. filter_bias_and_bn=filter_bias_and_bn,
  1222. )