lr_scheduler.py 101 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602
  1. # mypy: allow-untyped-defs
  2. r"""Learning Rate Scheduler."""
  3. from __future__ import annotations
  4. import math
  5. import types
  6. import warnings
  7. from bisect import bisect_right
  8. from collections import Counter
  9. from functools import partial, wraps
  10. from typing import Any, cast, Literal, SupportsFloat, TYPE_CHECKING, TypedDict
  11. from typing_extensions import override, Self
  12. from weakref import ref
  13. from torch import inf, Tensor
  14. from .optimizer import _to_scalar, Optimizer
  15. if TYPE_CHECKING:
  16. from collections.abc import Callable, Iterable, Sequence
# Names exported by a star-import of this module.
__all__ = [
    "LambdaLR",
    "MultiplicativeLR",
    "StepLR",
    "MultiStepLR",
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "SequentialLR",
    "CosineAnnealingLR",
    "ChainedScheduler",
    "ReduceLROnPlateau",
    "CyclicLR",
    "CosineAnnealingWarmRestarts",
    "OneCycleLR",
    "PolynomialLR",
    "LRScheduler",
]
# Message passed to ``warnings.warn`` by ``LRScheduler.step`` whenever it is
# called with an explicit ``epoch`` argument.
EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
    "deprecated where possible. Please use `scheduler.step()` to step the "
    "scheduler. During the deprecation, if epoch is different from None, the "
    "closed form is used instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)
  43. def _format_param(name: str, optimizer: Optimizer, param):
  44. """Return correctly formatted lr/momentum for each param group."""
  45. def _copy(_param):
  46. return _param.clone() if isinstance(_param, Tensor) else _param
  47. if isinstance(param, (list, tuple)):
  48. if len(param) != len(optimizer.param_groups):
  49. raise ValueError(
  50. f"{name} must have the same length as optimizer.param_groups. "
  51. f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}."
  52. )
  53. else:
  54. param = [param] * len(optimizer.param_groups)
  55. return list(map(_copy, param))
  56. def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
  57. """Create a list containing group[key] for each optimizer param_group.
  58. Prevents aliasing when group[key] could be a Tensor.
  59. Raises a KeyError when group[key] does not exist.
  60. """
  61. return [
  62. group[key].clone() if isinstance(group[key], Tensor) else group[key]
  63. for group in optimizer.param_groups
  64. ]
  65. def _update_param_group_val(
  66. param_group: dict[str, Any], key: str, val: float | Tensor
  67. ) -> None:
  68. """Set param_group[key] to val without aliasing or assignment when they're
  69. both tensors. Raises a KeyError if param_group[key] does not exist.
  70. """
  71. if isinstance(param_group[key], Tensor):
  72. param_group[key].fill_(_to_scalar(val))
  73. else:
  74. param_group[key] = val
  75. class LRScheduler:
  76. r"""Base class for all learning rate schedulers.
  77. Subclasses implement :meth:`get_lr` and optionally override :meth:`step` to
  78. define scheduling behavior.
  79. Args:
  80. optimizer (Optimizer): The optimizer this scheduler will adjust the
  81. learning rates of.
  82. last_epoch (int): Index of the last epoch seen by the scheduler. Use
  83. ``-1`` (default) to initialize the scheduler. Only use a non-default
  84. value when restoring this scheduler from a saved checkpoint.
  85. .. warning::
  86. Initializing a scheduler overwrites its optimizer's
  87. ``param_group["lr"]``\s. When restoring a checkpoint, initialize the
  88. scheduler **before** calling your optimizer's
  89. :meth:`~torch.optim.Optimizer.load_state_dict` to avoid overwriting the
  90. loaded learning rates.
  91. """
  92. _get_lr_called_within_step: bool = False
  93. _is_initial: bool = False
  94. def __init__(
  95. self,
  96. optimizer: Optimizer,
  97. last_epoch: int = -1,
  98. ) -> None: # noqa: D107
  99. # Attach optimizer
  100. if not isinstance(optimizer, Optimizer):
  101. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  102. self.optimizer = optimizer
  103. # Initialize epoch and base learning rates
  104. if last_epoch == -1:
  105. for group in optimizer.param_groups:
  106. initial_lr = group["lr"]
  107. if isinstance(initial_lr, Tensor):
  108. initial_lr = initial_lr.clone()
  109. group.setdefault("initial_lr", initial_lr)
  110. else:
  111. for i, group in enumerate(optimizer.param_groups):
  112. if "initial_lr" not in group:
  113. raise KeyError(
  114. f"param 'initial_lr' is not specified in param_groups[{i}] when resuming scheduler with last_epoch >= 0.\n"
  115. "This typically happens when:\n"
  116. "1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
  117. "2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
  118. )
  119. self.base_lrs: list[float | Tensor] = _param_groups_val_list(
  120. optimizer, "initial_lr"
  121. )
  122. self.last_epoch = last_epoch
  123. # Following https://github.com/pytorch/pytorch/issues/20124
  124. # We would like to ensure that `lr_scheduler.step()` is called after
  125. # `optimizer.step()`
  126. def patch_track_step_called(opt: Optimizer):
  127. if hasattr(opt.step, "_wrapped_by_lr_sched"):
  128. # we've already patched
  129. return opt.step
  130. def wrap_step(step_fn):
  131. opt_ref = ref(self.optimizer)
  132. func = step_fn.__func__
  133. @wraps(func)
  134. def wrapper(*args, **kwargs):
  135. opt = opt_ref()
  136. opt._opt_called = True # type: ignore[union-attr]
  137. return func.__get__(opt, opt.__class__)(*args, **kwargs)
  138. wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined]
  139. return wrapper
  140. opt.step = wrap_step(opt.step) # type: ignore[method-assign]
  141. patch_track_step_called(self.optimizer)
  142. self._initial_step()
  143. def _initial_step(self) -> None:
  144. """Initialize step counts and perform a step."""
  145. self._step_count = 0
  146. with _initial_mode(self):
  147. self.step()
  148. def state_dict(self) -> dict[str, Any]:
  149. """Return the state of the scheduler as a :class:`dict`.
  150. It contains an entry for every variable in ``self.__dict__`` which
  151. is not the optimizer.
  152. """
  153. return {
  154. key: value for key, value in self.__dict__.items() if key != "optimizer"
  155. }
  156. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  157. """Load the scheduler's state.
  158. Args:
  159. state_dict (dict): scheduler state. Should be an object returned
  160. from a call to :meth:`state_dict`.
  161. """
  162. self.__dict__.update(state_dict)
  163. def get_last_lr(self) -> list[float | Tensor]:
  164. r"""Get the most recent learning rates computed by this scheduler.
  165. Returns:
  166. list[float | Tensor]: A :class:`list` of learning rates with entries
  167. for each of the optimizer's
  168. :attr:`~torch.optim.Optimizer.param_groups`, with the same types as
  169. their ``group["lr"]``\s.
  170. .. note::
  171. The returned :class:`~torch.Tensor`\s are copies, and never alias
  172. the optimizer's ``group["lr"]``\s.
  173. """
  174. # We always update self._last_lr with _param_groups_val_list, so it's a
  175. # .clone() of the group["lr"]s. If we didn't do this, the user could
  176. # corrupt their learning rates by modifying the outputs in place.
  177. return self._last_lr
  178. def get_lr(self) -> list[float | Tensor]:
  179. r"""Compute the next learning rate for each of the optimizer's
  180. :attr:`~torch.optim.Optimizer.param_groups`.
  181. Returns:
  182. list[float | Tensor]: A :class:`list` of learning rates for each of
  183. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  184. same types as their current ``group["lr"]``\s.
  185. .. note::
  186. If you're trying to inspect the most recent learning rate, use
  187. :meth:`get_last_lr()` instead.
  188. .. note::
  189. The returned :class:`~torch.Tensor`\s are copies, and never alias
  190. the optimizer's ``group["lr"]``\s.
  191. """
  192. raise NotImplementedError
  193. def step(self, epoch: int | None = None) -> None:
  194. """Step the scheduler.
  195. Args:
  196. epoch (int, optional):
  197. .. deprecated:: 1.4
  198. If provided, sets :attr:`last_epoch` to ``epoch`` and uses
  199. :meth:`_get_closed_form_lr` if it is available. This is not
  200. universally supported. Use :meth:`step` without arguments
  201. instead.
  202. .. note::
  203. Call this method after calling the optimizer's
  204. :meth:`~torch.optim.Optimizer.step`.
  205. """
  206. # Raise a warning if old pattern is detected
  207. # https://github.com/pytorch/pytorch/issues/20124
  208. if self._step_count == 1:
  209. if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
  210. warnings.warn(
  211. "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
  212. "initialization. Please, make sure to call `optimizer.step()` before "
  213. "`lr_scheduler.step()`. See more details at "
  214. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
  215. UserWarning,
  216. stacklevel=2,
  217. )
  218. # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
  219. elif not getattr(self.optimizer, "_opt_called", False):
  220. warnings.warn(
  221. "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
  222. "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
  223. "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
  224. "will result in PyTorch skipping the first value of the learning rate schedule. "
  225. "See more details at "
  226. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
  227. UserWarning,
  228. stacklevel=2,
  229. )
  230. self._step_count += 1
  231. if epoch is not None:
  232. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2)
  233. self._update_lr(epoch)
  234. def _update_lr(self, epoch: int | None = None) -> None:
  235. with _enable_get_lr_call(self):
  236. if epoch is None:
  237. self.last_epoch += 1
  238. values = self.get_lr()
  239. else:
  240. self.last_epoch = epoch
  241. if hasattr(self, "_get_closed_form_lr"):
  242. values = cast(list[float | Tensor], self._get_closed_form_lr())
  243. else:
  244. values = self.get_lr()
  245. for param_group, lr in zip(self.optimizer.param_groups, values, strict=True):
  246. _update_param_group_val(param_group, "lr", lr)
  247. self._last_lr: list[float | Tensor] = _param_groups_val_list(
  248. self.optimizer, "lr"
  249. )
  250. def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None:
  251. if not lr_scheduler._get_lr_called_within_step:
  252. warnings.warn(
  253. "To get the last learning rate computed by the scheduler, "
  254. "please use `get_last_lr()`.",
  255. UserWarning,
  256. stacklevel=2,
  257. )
# _LRScheduler is kept for backwards compatibility.
# It is a subclass rather than a plain alias because we want __name__ of
# _LRScheduler to be "_LRScheduler" (an alias assignment would make it
# "LRScheduler").
class _LRScheduler(LRScheduler):
    pass
  262. class _enable_get_lr_call:
  263. def __init__(self, o: LRScheduler) -> None:
  264. self.o = o
  265. def __enter__(self) -> Self:
  266. self.o._get_lr_called_within_step = True
  267. return self
  268. def __exit__(self, type, value, traceback) -> None:
  269. self.o._get_lr_called_within_step = False
  270. class _initial_mode:
  271. def __init__(self, o: LRScheduler) -> None:
  272. self.o = o
  273. def __enter__(self):
  274. self.o._is_initial = True
  275. def __exit__(self, type, value, traceback):
  276. self.o._is_initial = False
  277. class LambdaLR(LRScheduler):
  278. """Sets the initial learning rate.
  279. The learning rate of each parameter group is set to the initial lr
  280. times a given function. When last_epoch=-1, sets initial lr as lr.
  281. Args:
  282. optimizer (Optimizer): Wrapped optimizer.
  283. lr_lambda (function or list): A function which computes a multiplicative
  284. factor given an integer parameter epoch, or a list of such
  285. functions, one for each group in optimizer.param_groups.
  286. last_epoch (int): The index of last epoch. Default: -1.
  287. Example:
  288. >>> # xdoctest: +SKIP
  289. >>> # Assuming optimizer has two groups.
  290. >>> num_epochs = 100
  291. >>> lambda1 = lambda epoch: epoch // 30
  292. >>> lambda2 = lambda epoch: 0.95**epoch
  293. >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
  294. >>> for epoch in range(num_epochs):
  295. >>> train(...)
  296. >>> validate(...)
  297. >>> scheduler.step()
  298. >>>
  299. >>> # Alternatively, you can use a single lambda function for all groups.
  300. >>> scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)
  301. >>> for epoch in range(num_epochs):
  302. >>> train(...)
  303. >>> validate(...)
  304. >>> scheduler.step()
  305. .. image:: ../scripts/lr_scheduler_images/LambdaLR.png
  306. """
  307. def __init__(
  308. self,
  309. optimizer: Optimizer,
  310. lr_lambda: Callable[[int], float] | list[Callable[[int], float]],
  311. last_epoch: int = -1,
  312. ) -> None: # noqa: D107
  313. self.optimizer = optimizer
  314. self.lr_lambdas: list[Callable[[int], float]]
  315. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  316. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  317. else:
  318. if len(lr_lambda) != len(optimizer.param_groups):
  319. raise ValueError(
  320. f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
  321. )
  322. self.lr_lambdas = list(lr_lambda)
  323. super().__init__(optimizer, last_epoch)
  324. @override
  325. def state_dict(self) -> dict[str, Any]:
  326. """Return the state of the scheduler as a :class:`dict`.
  327. It contains an entry for every variable in ``self.__dict__`` which is not the optimizer.
  328. The learning rate lambda functions will only be saved if they are callable objects
  329. and not if they are functions or lambdas.
  330. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  331. """
  332. state_dict = {
  333. key: value
  334. for key, value in self.__dict__.items()
  335. if key not in ("optimizer", "lr_lambdas")
  336. }
  337. state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
  338. for idx, fn in enumerate(self.lr_lambdas):
  339. if not isinstance(fn, types.FunctionType):
  340. # pyrefly: ignore [unsupported-operation]
  341. state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
  342. return state_dict
  343. @override
  344. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  345. """Load the scheduler's state.
  346. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  347. Args:
  348. state_dict (dict): scheduler state. Should be an object returned
  349. from a call to :meth:`state_dict`.
  350. """
  351. lr_lambdas = state_dict.pop("lr_lambdas")
  352. self.__dict__.update(state_dict)
  353. # Restore state_dict keys in order to prevent side effects
  354. # https://github.com/pytorch/pytorch/issues/32756
  355. state_dict["lr_lambdas"] = lr_lambdas
  356. for idx, fn in enumerate(lr_lambdas):
  357. if fn is not None:
  358. self.lr_lambdas[idx].__dict__.update(fn)
  359. @override
  360. def get_lr(self) -> list[float | Tensor]:
  361. r"""Compute the next learning rate for each of the optimizer's
  362. :attr:`~torch.optim.Optimizer.param_groups`.
  363. Scales the :attr:`base_lrs` by the outputs of the :attr:`lr_lambdas` at
  364. :attr:`last_epoch`.
  365. Returns:
  366. list[float | Tensor]: A :class:`list` of learning rates for each of
  367. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  368. same types as their current ``group["lr"]``\s.
  369. .. note::
  370. If you're trying to inspect the most recent learning rate, use
  371. :meth:`get_last_lr()` instead.
  372. .. note::
  373. The returned :class:`~torch.Tensor`\s are copies, and never alias
  374. the optimizer's ``group["lr"]``\s.
  375. """
  376. _warn_get_lr_called_within_step(self)
  377. return [
  378. base_lr * lmbda(self.last_epoch)
  379. for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs, strict=True)
  380. ]
  381. class MultiplicativeLR(LRScheduler):
  382. """Multiply the learning rate of each parameter group by the factor given in the specified function.
  383. When last_epoch=-1, set initial lr as lr.
  384. Args:
  385. optimizer (Optimizer): Wrapped optimizer.
  386. lr_lambda (function or list): A function which computes a multiplicative
  387. factor given an integer parameter epoch, or a list of such
  388. functions, one for each group in optimizer.param_groups.
  389. last_epoch (int): The index of last epoch. Default: -1.
  390. Example:
  391. >>> # xdoctest: +SKIP
  392. >>> lmbda = lambda epoch: 0.95
  393. >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
  394. >>> for epoch in range(100):
  395. >>> train(...)
  396. >>> validate(...)
  397. >>> scheduler.step()
  398. .. image:: ../scripts/lr_scheduler_images/MultiplicativeLR.png
  399. """
  400. def __init__(
  401. self,
  402. optimizer: Optimizer,
  403. lr_lambda: Callable[[int], float] | list[Callable[[int], float]],
  404. last_epoch: int = -1,
  405. ) -> None: # noqa: D107
  406. self.optimizer = optimizer
  407. self.lr_lambdas: list[Callable[[int], float]]
  408. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  409. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  410. else:
  411. if len(lr_lambda) != len(optimizer.param_groups):
  412. raise ValueError(
  413. f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
  414. )
  415. self.lr_lambdas = list(lr_lambda)
  416. for lr_lambda in self.lr_lambdas:
  417. if not callable(lr_lambda):
  418. raise TypeError(
  419. f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
  420. )
  421. super().__init__(optimizer, last_epoch)
  422. @override
  423. def state_dict(self) -> dict[str, Any]:
  424. """Return the state of the scheduler as a :class:`dict`.
  425. It contains an entry for every variable in ``self.__dict__`` which
  426. is not the optimizer.
  427. The learning rate lambda functions will only be saved if they are callable objects
  428. and not if they are functions or lambdas.
  429. """
  430. state_dict = {
  431. key: value
  432. for key, value in self.__dict__.items()
  433. if key not in ("optimizer", "lr_lambdas")
  434. }
  435. state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
  436. for idx, fn in enumerate(self.lr_lambdas):
  437. if not isinstance(fn, types.FunctionType):
  438. # pyrefly: ignore [unsupported-operation]
  439. state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
  440. return state_dict
  441. @override
  442. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  443. """Load the scheduler's state.
  444. Args:
  445. state_dict (dict): scheduler state. Should be an object returned
  446. from a call to :meth:`state_dict`.
  447. """
  448. lr_lambdas = state_dict.pop("lr_lambdas")
  449. self.__dict__.update(state_dict)
  450. # Restore state_dict keys in order to prevent side effects
  451. # https://github.com/pytorch/pytorch/issues/32756
  452. state_dict["lr_lambdas"] = lr_lambdas
  453. for idx, fn in enumerate(lr_lambdas):
  454. if fn is not None:
  455. self.lr_lambdas[idx].__dict__.update(fn)
  456. @override
  457. def get_lr(self) -> list[float | Tensor]:
  458. r"""Compute the next learning rate for each of the optimizer's
  459. :attr:`~torch.optim.Optimizer.param_groups`.
  460. Scales the current ``group["lr"]``\s in each of the optimizer's
  461. :attr:`~torch.optim.Optimizer.param_groups` by the outputs of the
  462. :attr:`lr_lambdas` at :attr:`last_epoch`.
  463. Returns:
  464. list[float | Tensor]: A :class:`list` of learning rates for each of
  465. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  466. same types as their current ``group["lr"]``\s.
  467. .. note::
  468. If you're trying to inspect the most recent learning rate, use
  469. :meth:`get_last_lr()` instead.
  470. .. note::
  471. The returned :class:`~torch.Tensor`\s are copies, and never alias
  472. the optimizer's ``group["lr"]``\s.
  473. """
  474. _warn_get_lr_called_within_step(self)
  475. if not self._is_initial:
  476. return [
  477. group["lr"] * lmbda(self.last_epoch)
  478. for lmbda, group in zip(
  479. self.lr_lambdas, self.optimizer.param_groups, strict=True
  480. )
  481. ]
  482. else:
  483. return _param_groups_val_list(self.optimizer, "lr")
  484. class StepLR(LRScheduler):
  485. """Decays the learning rate of each parameter group by gamma every step_size epochs.
  486. Notice that such decay can happen simultaneously with other changes to the learning rate
  487. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  488. Args:
  489. optimizer (Optimizer): Wrapped optimizer.
  490. step_size (int): Period of learning rate decay.
  491. gamma (float): Multiplicative factor of learning rate decay.
  492. Default: 0.1.
  493. last_epoch (int): The index of last epoch. Default: -1.
  494. Example:
  495. >>> # xdoctest: +SKIP
  496. >>> # Assuming optimizer uses lr = 0.05 for all groups
  497. >>> # lr = 0.05 if epoch < 30
  498. >>> # lr = 0.005 if 30 <= epoch < 60
  499. >>> # lr = 0.0005 if 60 <= epoch < 90
  500. >>> # ...
  501. >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
  502. >>> for epoch in range(100):
  503. >>> train(...)
  504. >>> validate(...)
  505. >>> scheduler.step()
  506. .. image:: ../scripts/lr_scheduler_images/StepLR.png
  507. """
  508. def __init__(
  509. self,
  510. optimizer: Optimizer,
  511. step_size: int,
  512. gamma: float = 0.1,
  513. last_epoch: int = -1,
  514. ) -> None: # noqa: D107
  515. self.step_size = step_size
  516. self.gamma = gamma
  517. super().__init__(optimizer, last_epoch)
  518. @override
  519. def get_lr(self) -> list[float | Tensor]:
  520. r"""Compute the next learning rate for each of the optimizer's
  521. :attr:`~torch.optim.Optimizer.param_groups`.
  522. If the current epoch is a non-zero multiple of :attr:`step_size`, we
  523. scale the current ``group["lr"]``\s in the optimizer's
  524. :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
  525. Returns:
  526. list[float | Tensor]: A :class:`list` of learning rates for each of
  527. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  528. same types as their current ``group["lr"]``\s.
  529. .. note::
  530. If you're trying to inspect the most recent learning rate, use
  531. :meth:`get_last_lr()` instead.
  532. .. note::
  533. The returned :class:`~torch.Tensor`\s are copies, and never alias
  534. the optimizer's ``group["lr"]``\s.
  535. """
  536. _warn_get_lr_called_within_step(self)
  537. if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
  538. return _param_groups_val_list(self.optimizer, "lr")
  539. return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
  540. def _get_closed_form_lr(self) -> list[float | Tensor]:
  541. r"""Compute learning rates for each of the optimizer's
  542. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  543. a closed-form formula.
  544. Uses :attr:`base_lrs` to compute learning rates. This method is called
  545. when an epoch is passed to :meth:`step`.
  546. Returns:
  547. list[float | Tensor]: A :class:`list` of learning rates for each of
  548. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  549. same types as their current ``group["lr"]``\s.
  550. """
  551. return [
  552. base_lr * self.gamma ** (self.last_epoch // self.step_size)
  553. for base_lr in self.base_lrs
  554. ]
  555. class MultiStepLR(LRScheduler):
  556. """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones.
  557. Notice that such decay can happen simultaneously with other changes to the learning rate
  558. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  559. Args:
  560. optimizer (Optimizer): Wrapped optimizer.
  561. milestones (list): List of epoch indices. Must be increasing.
  562. gamma (float): Multiplicative factor of learning rate decay.
  563. Default: 0.1.
  564. last_epoch (int): The index of last epoch. Default: -1.
  565. Example:
  566. >>> # xdoctest: +SKIP
  567. >>> # Assuming optimizer uses lr = 0.05 for all groups
  568. >>> # lr = 0.05 if epoch < 30
  569. >>> # lr = 0.005 if 30 <= epoch < 80
  570. >>> # lr = 0.0005 if epoch >= 80
  571. >>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
  572. >>> for epoch in range(100):
  573. >>> train(...)
  574. >>> validate(...)
  575. >>> scheduler.step()
  576. .. image:: ../scripts/lr_scheduler_images/MultiStepLR.png
  577. """
  578. def __init__(
  579. self,
  580. optimizer: Optimizer,
  581. milestones: Iterable[int],
  582. gamma: float = 0.1,
  583. last_epoch: int = -1,
  584. ) -> None: # noqa: D107
  585. self.milestones = Counter(milestones)
  586. self.gamma = gamma
  587. super().__init__(optimizer, last_epoch)
  588. @override
  589. def get_lr(self) -> list[float | Tensor]:
  590. r"""Compute the next learning rate for each of the optimizer's
  591. :attr:`~torch.optim.Optimizer.param_groups`.
  592. If the current epoch is in :attr:`milestones`, decays the
  593. ``group["lr"]``\s in the optimizer's
  594. :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
  595. Returns:
  596. list[float | Tensor]: A :class:`list` of learning rates for each of
  597. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  598. same types as their current ``group["lr"]``\s.
  599. .. note::
  600. If you're trying to inspect the most recent learning rate, use
  601. :meth:`get_last_lr()` instead.
  602. .. note::
  603. The returned :class:`~torch.Tensor`\s are copies, and never alias
  604. the optimizer's ``group["lr"]``\s.
  605. .. note::
  606. If the current epoch appears in :attr:`milestones` ``n`` times, we
  607. scale by :attr:`gamma` to the power of ``n``
  608. """
  609. _warn_get_lr_called_within_step(self)
  610. if self.last_epoch not in self.milestones:
  611. return _param_groups_val_list(self.optimizer, "lr")
  612. return [
  613. group["lr"] * self.gamma ** self.milestones[self.last_epoch]
  614. for group in self.optimizer.param_groups
  615. ]
  616. def _get_closed_form_lr(self):
  617. r"""Compute learning rates for each of the optimizer's
  618. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  619. a closed-form formula.
  620. Uses :attr:`base_lrs` to compute learning rates. This method is called
  621. when an epoch is passed to :meth:`step`.
  622. Returns:
  623. list[float | Tensor]: A :class:`list` of learning rates for each of
  624. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  625. same types as their current ``group["lr"]``\s.
  626. """
  627. milestones = sorted(self.milestones.elements())
  628. return [
  629. base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
  630. for base_lr in self.base_lrs
  631. ]
  632. class ConstantLR(LRScheduler):
  633. """Multiply the learning rate of each parameter group by a small constant factor.
  634. The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
  635. Notice that such multiplication of the small constant factor can
  636. happen simultaneously with other changes to the learning rate from outside this scheduler.
  637. When last_epoch=-1, sets initial lr as lr.
  638. Args:
  639. optimizer (Optimizer): Wrapped optimizer.
  640. factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
  641. total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor.
  642. Default: 5.
  643. last_epoch (int): The index of the last epoch. Default: -1.
  644. Example:
  645. >>> # xdoctest: +SKIP
  646. >>> # Assuming optimizer uses lr = 0.05 for all groups
  647. >>> # lr = 0.025 if epoch == 0
  648. >>> # lr = 0.025 if epoch == 1
  649. >>> # lr = 0.025 if epoch == 2
  650. >>> # lr = 0.025 if epoch == 3
  651. >>> # ...
  652. >>> # lr = 0.05 if epoch >= 40
  653. >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=40)
  654. >>> for epoch in range(100):
  655. >>> train(...)
  656. >>> validate(...)
  657. >>> scheduler.step()
  658. .. image:: ../scripts/lr_scheduler_images/ConstantLR.png
  659. """
  660. def __init__(
  661. self,
  662. optimizer: Optimizer,
  663. factor: float = 1.0 / 3,
  664. total_iters: int = 5,
  665. last_epoch: int = -1,
  666. ) -> None: # noqa: D107
  667. if factor > 1.0 or factor < 0:
  668. raise ValueError(
  669. "Constant multiplicative factor expected to be between 0 and 1."
  670. )
  671. self.factor = factor
  672. self.total_iters = total_iters
  673. super().__init__(optimizer, last_epoch)
  674. @override
  675. def get_lr(self) -> list[float | Tensor]:
  676. r"""Compute the next learning rate for each of the optimizer's
  677. :attr:`~torch.optim.Optimizer.param_groups`.
  678. When :attr:`last_epoch` is 0, this method scales the ``group["lr"]``\s
  679. in each of the optimizer's :attr:`~torch.optim.Optimizer.param_groups`
  680. by :attr:`factor`. Once :attr:`total_iters` is reached, it undoes this,
  681. scaling by ``1 / factor``.
  682. Returns:
  683. list[float | Tensor]: A :class:`list` of learning rates for each of
  684. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  685. same types as their current ``group["lr"]``\s.
  686. .. note::
  687. If you're trying to inspect the most recent learning rate, use
  688. :meth:`get_last_lr()` instead.
  689. .. note::
  690. The returned :class:`~torch.Tensor`\s are copies, and never alias
  691. the optimizer's ``group["lr"]``\s.
  692. """
  693. _warn_get_lr_called_within_step(self)
  694. if self.last_epoch == 0:
  695. return [group["lr"] * self.factor for group in self.optimizer.param_groups]
  696. if self.last_epoch != self.total_iters:
  697. return _param_groups_val_list(self.optimizer, "lr")
  698. return [
  699. group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
  700. ]
  701. def _get_closed_form_lr(self):
  702. r"""Compute learning rates for each of the optimizer's
  703. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  704. a closed-form formula.
  705. Uses :attr:`base_lrs` to compute learning rates. This method is called
  706. when an epoch is passed to :meth:`step`.
  707. Returns:
  708. list[float | Tensor]: A :class:`list` of learning rates for each of
  709. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  710. same types as their current ``group["lr"]``\s.
  711. """
  712. return [
  713. base_lr
  714. * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
  715. for base_lr in self.base_lrs
  716. ]
  717. class LinearLR(LRScheduler):
  718. """Decays the learning rate of each parameter group by linearly changing small multiplicative factor.
  719. The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
  720. Notice that such decay can happen simultaneously with other changes to the learning rate
  721. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  722. Args:
  723. optimizer (Optimizer): Wrapped optimizer.
  724. start_factor (float): The number we multiply learning rate in the first epoch.
  725. The multiplication factor changes towards end_factor in the following epochs.
  726. Default: 1./3.
  727. end_factor (float): The number we multiply learning rate at the end of linear changing
  728. process. Default: 1.0.
  729. total_iters (int): The number of iterations that multiplicative factor reaches to 1.
  730. Default: 5.
  731. last_epoch (int): The index of the last epoch. Default: -1.
  732. Example:
  733. >>> # xdoctest: +SKIP
  734. >>> # Assuming optimizer uses lr = 0.05 for all groups
  735. >>> # lr = 0.003687 if epoch == 0
  736. >>> # lr = 0.004875 if epoch == 1
  737. >>> # lr = 0.006062 if epoch == 2
  738. >>> # lr = 0.00725 if epoch == 3
  739. >>> # ...
  740. >>> # lr = 0.05 if epoch >= 40
  741. >>> scheduler = LinearLR(optimizer, start_factor=0.05, total_iters=40)
  742. >>> for epoch in range(100):
  743. >>> train(...)
  744. >>> validate(...)
  745. >>> scheduler.step()
  746. .. image:: ../scripts/lr_scheduler_images/LinearLR.png
  747. """
  748. def __init__(
  749. self,
  750. optimizer: Optimizer,
  751. start_factor: float = 1.0 / 3,
  752. end_factor: float = 1.0,
  753. total_iters: int = 5,
  754. last_epoch: int = -1,
  755. ) -> None: # noqa: D107
  756. if start_factor > 1.0 or start_factor <= 0:
  757. raise ValueError(
  758. "Starting multiplicative factor expected to be greater than 0 and less or equal to 1."
  759. )
  760. if end_factor > 1.0 or end_factor < 0:
  761. raise ValueError(
  762. "Ending multiplicative factor expected to be between 0 and 1."
  763. )
  764. self.start_factor = start_factor
  765. self.end_factor = end_factor
  766. self.total_iters = total_iters
  767. super().__init__(optimizer, last_epoch)
  768. @override
  769. def get_lr(self) -> list[float | Tensor]:
  770. r"""Compute the next learning rate for each of the optimizer's
  771. :attr:`~torch.optim.Optimizer.param_groups`.
  772. Scales the ``group["lr"]``\s in the optimizer's
  773. :attr:`~torch.optim.Optimizer.param_groups` such that successive steps
  774. interpolate linearly from :attr:`start_factor` up to :attr:`end_factor`
  775. across :attr:`total_iters` steps.
  776. Returns:
  777. list[float | Tensor]: A :class:`list` of learning rates for each of
  778. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  779. same types as their current ``group["lr"]``\s.
  780. .. note::
  781. If you're trying to inspect the most recent learning rate, use
  782. :meth:`get_last_lr()` instead.
  783. .. note::
  784. The returned :class:`~torch.Tensor`\s are copies, and never alias
  785. the optimizer's ``group["lr"]``\s.
  786. """
  787. _warn_get_lr_called_within_step(self)
  788. if self.last_epoch == 0:
  789. return [
  790. group["lr"] * self.start_factor for group in self.optimizer.param_groups
  791. ]
  792. if self._is_initial or self.last_epoch > self.total_iters:
  793. return _param_groups_val_list(self.optimizer, "lr")
  794. return [
  795. group["lr"]
  796. * (
  797. 1.0
  798. + (self.end_factor - self.start_factor)
  799. / (
  800. self.total_iters * self.start_factor
  801. + (self.last_epoch - 1) * (self.end_factor - self.start_factor)
  802. )
  803. )
  804. for group in self.optimizer.param_groups
  805. ]
  806. def _get_closed_form_lr(self):
  807. r"""Compute learning rates for each of the optimizer's
  808. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  809. a closed-form formula.
  810. Uses :attr:`base_lrs` to compute learning rates. This method is called
  811. when an epoch is passed to :meth:`step`.
  812. Returns:
  813. list[float | Tensor]: A :class:`list` of learning rates for each of
  814. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  815. same types as their current ``group["lr"]``\s.
  816. """
  817. return [
  818. base_lr
  819. * (
  820. self.start_factor
  821. + (self.end_factor - self.start_factor)
  822. * min(self.total_iters, self.last_epoch)
  823. / self.total_iters
  824. )
  825. for base_lr in self.base_lrs
  826. ]
  827. class ExponentialLR(LRScheduler):
  828. """Decays the learning rate of each parameter group by gamma every epoch.
  829. When last_epoch=-1, sets initial lr as lr.
  830. Args:
  831. optimizer (Optimizer): Wrapped optimizer.
  832. gamma (float): Multiplicative factor of learning rate decay.
  833. last_epoch (int): The index of last epoch. Default: -1.
  834. Example:
  835. >>> # xdoctest: +SKIP
  836. >>> scheduler = ExponentialLR(optimizer, gamma=0.95)
  837. >>> for epoch in range(100):
  838. >>> train(...)
  839. >>> validate(...)
  840. >>> scheduler.step()
  841. .. image:: ../scripts/lr_scheduler_images/ExponentialLR.png
  842. """
  843. def __init__(
  844. self,
  845. optimizer: Optimizer,
  846. gamma: float,
  847. last_epoch: int = -1,
  848. ) -> None: # noqa: D107
  849. self.gamma = gamma
  850. super().__init__(optimizer, last_epoch)
  851. @override
  852. def get_lr(self) -> list[float | Tensor]:
  853. r"""Compute the next learning rate for each of the optimizer's
  854. :attr:`~torch.optim.Optimizer.param_groups`.
  855. Multiplies the current ``group["lr"]``\s in the optimizer's
  856. :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
  857. Returns:
  858. list[float | Tensor]: A :class:`list` of learning rates for each of
  859. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  860. same types as their current ``group["lr"]``\s.
  861. .. note::
  862. If you're trying to inspect the most recent learning rate, use
  863. :meth:`get_last_lr()` instead.
  864. .. note::
  865. The returned :class:`~torch.Tensor`\s are copies, and never alias
  866. the optimizer's ``group["lr"]``\s.
  867. """
  868. _warn_get_lr_called_within_step(self)
  869. # when loading from a checkpoint, we don't want _initial_step (called from the constructor)
  870. # to update the lr one more step ahead of itself.
  871. if self._is_initial:
  872. return _param_groups_val_list(self.optimizer, "lr")
  873. return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
  874. def _get_closed_form_lr(self):
  875. r"""Compute learning rates for each of the optimizer's
  876. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  877. a closed-form formula.
  878. Uses :attr:`base_lrs` to compute learning rates. This method is called
  879. when an epoch is passed to :meth:`step`.
  880. Returns:
  881. list[float | Tensor]: A :class:`list` of learning rates for each of
  882. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  883. same types as their current ``group["lr"]``\s.
  884. """
  885. return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
  886. class SequentialLR(LRScheduler):
  887. """Contains a list of schedulers expected to be called sequentially during the optimization process.
  888. Specifically, the schedulers will be called according to the milestone points, which should provide exact
  889. intervals by which each scheduler should be called at a given epoch.
  890. Args:
  891. optimizer (Optimizer): Wrapped optimizer.
  892. schedulers (list): List of chained schedulers.
  893. milestones (list): List of integers that reflects milestone points.
  894. last_epoch (int): The index of last epoch. Default: -1.
  895. Example:
  896. >>> # xdoctest: +SKIP
  897. >>> # Assuming optimizer uses lr = 0.05 for all groups
  898. >>> # lr = 0.005 if epoch == 0
  899. >>> # lr = 0.005 if epoch == 1
  900. >>> # lr = 0.005 if epoch == 2
  901. >>> # ...
  902. >>> # lr = 0.05 if epoch == 20
  903. >>> # lr = 0.045 if epoch == 21
  904. >>> # lr = 0.0405 if epoch == 22
  905. >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
  906. >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
  907. >>> scheduler = SequentialLR(
  908. ... optimizer,
  909. ... schedulers=[scheduler1, scheduler2],
  910. ... milestones=[20],
  911. ... )
  912. >>> for epoch in range(100):
  913. >>> train(...)
  914. >>> validate(...)
  915. >>> scheduler.step()
  916. .. image:: ../scripts/lr_scheduler_images/SequentialLR.png
  917. """
  918. def __init__(
  919. self,
  920. optimizer: Optimizer,
  921. schedulers: list[LRScheduler],
  922. milestones: list[int],
  923. last_epoch: int = -1,
  924. ) -> None: # noqa: D107
  925. if len(schedulers) < 1:
  926. raise ValueError(
  927. f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
  928. )
  929. for scheduler_idx, scheduler in enumerate(schedulers):
  930. if not hasattr(scheduler, "optimizer"):
  931. raise TypeError(
  932. f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
  933. )
  934. if isinstance(scheduler, ReduceLROnPlateau):
  935. raise ValueError(
  936. f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
  937. "requires additional kwargs to be specified when calling `step`, "
  938. f"but got one at index {scheduler_idx} in the given schedulers sequence."
  939. )
  940. if optimizer != scheduler.optimizer:
  941. raise ValueError(
  942. f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
  943. f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
  944. f"which is different from {optimizer.__class__.__name__}."
  945. )
  946. if len(milestones) != len(schedulers) - 1:
  947. raise ValueError(
  948. "Sequential Schedulers expects number of schedulers provided to be one more "
  949. f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
  950. f"number of milestones to be equal to {len(milestones)}"
  951. )
  952. self._schedulers = schedulers
  953. self._milestones = milestones
  954. self.last_epoch = last_epoch + 1
  955. self.optimizer = optimizer
  956. # Reset learning rates back to initial values
  957. for group in self.optimizer.param_groups:
  958. _update_param_group_val(group, "lr", group["initial_lr"])
  959. # "Undo" the step performed by other schedulers
  960. self.recursive_undo()
  961. # Perform the initial step for only the first scheduler
  962. self._schedulers[0]._initial_step()
  963. self._last_lr = schedulers[0].get_last_lr()
  964. def recursive_undo(self, sched=None) -> None:
  965. """
  966. Recursively undo any step performed by the initialisation of
  967. schedulers.
  968. """
  969. scheds = self if sched is None else sched
  970. if hasattr(scheds, "_schedulers"):
  971. for s in scheds._schedulers:
  972. self.recursive_undo(s)
  973. elif hasattr(scheds, "last_epoch"):
  974. scheds.last_epoch -= 1
  975. def step(self) -> None: # type: ignore[override]
  976. """Perform a step."""
  977. self.last_epoch += 1
  978. idx = bisect_right(self._milestones, self.last_epoch)
  979. scheduler = self._schedulers[idx]
  980. if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
  981. scheduler._update_lr(0)
  982. else:
  983. scheduler.step()
  984. self._last_lr = scheduler.get_last_lr()
  985. @override
  986. def state_dict(self) -> dict[str, Any]:
  987. """Return the state of the scheduler as a :class:`dict`.
  988. It contains an entry for every variable in ``self.__dict__`` which
  989. is not the optimizer.
  990. The wrapped scheduler states will also be saved.
  991. """
  992. state_dict = {
  993. key: value
  994. for key, value in self.__dict__.items()
  995. if key not in ("optimizer", "_schedulers")
  996. }
  997. state_dict["_schedulers"] = [None] * len(self._schedulers)
  998. for idx, s in enumerate(self._schedulers):
  999. # pyrefly: ignore [unsupported-operation]
  1000. state_dict["_schedulers"][idx] = s.state_dict()
  1001. return state_dict
  1002. @override
  1003. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  1004. """Load the scheduler's state.
  1005. Args:
  1006. state_dict (dict): scheduler state. Should be an object returned
  1007. from a call to :meth:`state_dict`.
  1008. """
  1009. _schedulers = state_dict.pop("_schedulers")
  1010. self.__dict__.update(state_dict)
  1011. # Restore state_dict keys in order to prevent side effects
  1012. # https://github.com/pytorch/pytorch/issues/32756
  1013. state_dict["_schedulers"] = _schedulers
  1014. for idx, s in enumerate(_schedulers):
  1015. self._schedulers[idx].load_state_dict(s)
  1016. class PolynomialLR(LRScheduler):
  1017. """Decays the learning rate of each parameter group using a polynomial function in the given total_iters.
  1018. When last_epoch=-1, sets initial lr as lr.
  1019. Args:
  1020. optimizer (Optimizer): Wrapped optimizer.
  1021. total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
  1022. power (float): The power of the polynomial. Default: 1.0.
  1023. Example:
  1024. >>> # xdoctest: +SKIP("undefined vars")
  1025. >>> # Assuming optimizer uses lr = 0.05 for all groups
  1026. >>> # lr = 0.0490 if epoch == 0
  1027. >>> # lr = 0.0481 if epoch == 1
  1028. >>> # lr = 0.0472 if epoch == 2
  1029. >>> # ...
  1030. >>> # lr = 0.0 if epoch >= 50
  1031. >>> scheduler = PolynomialLR(optimizer, total_iters=50, power=0.9)
  1032. >>> for epoch in range(100):
  1033. >>> train(...)
  1034. >>> validate(...)
  1035. >>> scheduler.step()
  1036. .. image:: ../scripts/lr_scheduler_images/PolynomialLR.png
  1037. """
  1038. def __init__(
  1039. self,
  1040. optimizer: Optimizer,
  1041. total_iters: int = 5,
  1042. power: float = 1.0,
  1043. last_epoch: int = -1,
  1044. ) -> None: # noqa: D107
  1045. self.total_iters = total_iters
  1046. self.power = power
  1047. super().__init__(optimizer, last_epoch)
  1048. @override
  1049. def get_lr(self) -> list[float | Tensor]:
  1050. r"""Compute the next learning rate for each of the optimizer's
  1051. :attr:`~torch.optim.Optimizer.param_groups`.
  1052. Scales the ``group["lr"]``\s in the optimizer's
  1053. :attr:`~torch.optim.Optimizer.param_groups` such that the learning rates
  1054. follow
  1055. .. math::
  1056. \texttt{base\_lr} \cdot \left(1 - \frac{\texttt{last\_epoch}}
  1057. {\texttt{total\_iters}} \right)^\texttt{power}
  1058. Returns the current learning rates unchanged after :attr:`total_iters`
  1059. is reached.
  1060. Returns:
  1061. list[float | Tensor]: A :class:`list` of learning rates for each of
  1062. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  1063. same types as their current ``group["lr"]``\s.
  1064. .. note::
  1065. If you're trying to inspect the most recent learning rate, use
  1066. :meth:`get_last_lr()` instead.
  1067. .. note::
  1068. The returned :class:`~torch.Tensor`\s are copies, and never alias
  1069. the optimizer's ``group["lr"]``\s.
  1070. """
  1071. _warn_get_lr_called_within_step(self)
  1072. if self._is_initial or self.last_epoch > self.total_iters:
  1073. return _param_groups_val_list(self.optimizer, "lr")
  1074. decay_factor = (
  1075. (1.0 - self.last_epoch / self.total_iters)
  1076. / (1.0 - (self.last_epoch - 1) / self.total_iters)
  1077. ) ** self.power
  1078. return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
  1079. def _get_closed_form_lr(self) -> list[float | Tensor]:
  1080. r"""Compute learning rates for each of the optimizer's
  1081. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  1082. a closed-form formula.
  1083. Uses :attr:`base_lrs` to compute learning rates. This method is called
  1084. when an epoch is passed to :meth:`step`.
  1085. Returns:
  1086. list[float | Tensor]: A :class:`list` of learning rates for each of
  1087. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  1088. same types as their current ``group["lr"]``\s.
  1089. """
  1090. return [
  1091. (
  1092. base_lr
  1093. * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters)
  1094. ** self.power
  1095. )
  1096. for base_lr in self.base_lrs
  1097. ]
  1098. class CosineAnnealingLR(LRScheduler):
  1099. r"""
  1100. Set the learning rate of each parameter group using a cosine annealing schedule.
  1101. The learning rate is updated recursively using:
  1102. .. math::
  1103. \eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min}) \cdot
  1104. \frac{1 + \cos\left(\frac{(T_{cur}+1) \pi}{T_{max}}\right)}
  1105. {1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right)}
  1106. This implements a recursive approximation of the closed-form schedule proposed in
  1107. `SGDR: Stochastic Gradient Descent with Warm Restarts`_:
  1108. .. math::
  1109. \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min}) \left(
  1110. 1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right) \right)
  1111. where:
  1112. - :math:`\eta_t` is the learning rate at step :math:`t`
  1113. - :math:`T_{cur}` is the number of epochs since the last restart
  1114. - :math:`T_{max}` is the maximum number of epochs in a cycle
  1115. Note:
  1116. Although SGDR includes periodic restarts, this implementation performs cosine annealing
  1117. **without restarts**, so :math:`T_{cur} = t` and increases monotonically with each call
  1118. to :meth:`step`.
  1119. Args:
  1120. optimizer (Optimizer): Wrapped optimizer.
  1121. T_max (int): Maximum number of iterations.
  1122. eta_min (float): Minimum learning rate. Default: 0.
  1123. last_epoch (int): The index of the last epoch. Default: -1.
  1124. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
  1125. https://arxiv.org/abs/1608.03983
  1126. Example:
  1127. >>> # xdoctest: +SKIP
  1128. >>> num_epochs = 100
  1129. >>> scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
  1130. >>> for epoch in range(num_epochs):
  1131. >>> train(...)
  1132. >>> validate(...)
  1133. >>> scheduler.step()
  1134. .. image:: ../scripts/lr_scheduler_images/CosineAnnealingLR.png
  1135. """
  1136. def __init__(
  1137. self,
  1138. optimizer: Optimizer,
  1139. T_max: int,
  1140. eta_min: float = 0.0,
  1141. last_epoch: int = -1,
  1142. ) -> None: # noqa: D107
  1143. self.T_max = T_max
  1144. self.eta_min = eta_min
  1145. super().__init__(optimizer, last_epoch)
  1146. @override
  1147. def get_lr(self) -> list[float | Tensor]:
  1148. r"""Compute the next learning rate for each of the optimizer's
  1149. :attr:`~torch.optim.Optimizer.param_groups`.
  1150. Scales the ``group["lr"]``\s in the optimizer's
  1151. :attr:`~torch.optim.Optimizer.param_groups` such that their learning
  1152. rates approximate
  1153. .. math::
  1154. \texttt{eta\_min} + \frac{1}{2} (\texttt{base\_lr} -
  1155. \texttt{eta\_min}) \left(1 + \cos\left(\pi \cdot
  1156. \frac{\texttt{last\_epoch}}{\texttt{T\_max}}\right) \right)
  1157. Returns:
  1158. list[float | Tensor]: A :class:`list` of learning rates for each of
  1159. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  1160. same types as their current ``group["lr"]``\s.
  1161. .. note::
  1162. If you're trying to inspect the most recent learning rate, use
  1163. :meth:`get_last_lr()` instead.
  1164. .. note::
  1165. The returned :class:`~torch.Tensor`\s are copies, and never alias
  1166. the optimizer's ``group["lr"]``\s.
  1167. """
  1168. _warn_get_lr_called_within_step(self)
  1169. if self._is_initial:
  1170. return _param_groups_val_list(self.optimizer, "lr")
  1171. elif self._step_count == 1 and self.last_epoch > 0:
  1172. return [
  1173. self.eta_min
  1174. + (base_lr - self.eta_min)
  1175. * (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
  1176. / 2
  1177. for base_lr, group in zip(
  1178. self.base_lrs, self.optimizer.param_groups, strict=True
  1179. )
  1180. ]
  1181. elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
  1182. return [
  1183. group["lr"]
  1184. + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
  1185. for base_lr, group in zip(
  1186. self.base_lrs, self.optimizer.param_groups, strict=True
  1187. )
  1188. ]
  1189. return [
  1190. (1 + math.cos(math.pi * self.last_epoch / self.T_max))
  1191. / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max))
  1192. * (group["lr"] - self.eta_min)
  1193. + self.eta_min
  1194. for group in self.optimizer.param_groups
  1195. ]
  1196. def _get_closed_form_lr(self) -> list[float | Tensor]:
  1197. r"""Compute learning rates for each of the optimizer's
  1198. :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
  1199. a closed-form formula.
  1200. Uses :attr:`base_lrs` to compute learning rates. This method is called
  1201. when an epoch is passed to :meth:`step`.
  1202. Returns:
  1203. list[float | Tensor]: A :class:`list` of learning rates for each of
  1204. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  1205. same types as their current ``group["lr"]``\s.
  1206. """
  1207. return [
  1208. self.eta_min
  1209. + (base_lr - self.eta_min)
  1210. * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
  1211. / 2
  1212. for base_lr in self.base_lrs
  1213. ]
  1214. class ChainedScheduler(LRScheduler):
  1215. """Chains a list of learning rate schedulers.
  1216. Takes in a sequence of chainable learning rate schedulers and calls their
  1217. step() functions consecutively in just one call to step().
  1218. Args:
  1219. schedulers (sequence): sequence of chained schedulers.
  1220. optimizer (Optimizer, optional): Wrapped optimizer. Default: None.
  1221. Example:
  1222. >>> # xdoctest: +SKIP
  1223. >>> # Assuming optimizer uses lr = 0.05 for all groups
  1224. >>> # lr = 0.05 if epoch == 0
  1225. >>> # lr = 0.0450 if epoch == 1
  1226. >>> # lr = 0.0405 if epoch == 2
  1227. >>> # ...
  1228. >>> # lr = 0.00675 if epoch == 19
  1229. >>> # lr = 0.06078 if epoch == 20
  1230. >>> # lr = 0.05470 if epoch == 21
  1231. >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
  1232. >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
  1233. >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
  1234. >>> for epoch in range(100):
  1235. >>> train(...)
  1236. >>> validate(...)
  1237. >>> scheduler.step()
  1238. .. image:: ../scripts/lr_scheduler_images/ChainedScheduler.png
  1239. """
  1240. def __init__(
  1241. self, schedulers: Sequence[LRScheduler], optimizer: Optimizer | None = None
  1242. ) -> None: # noqa: D107
  1243. if len(schedulers) < 1:
  1244. raise ValueError(
  1245. f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
  1246. )
  1247. optimizer = optimizer or schedulers[0].optimizer
  1248. for scheduler_idx, scheduler in enumerate(schedulers):
  1249. if not hasattr(scheduler, "optimizer"):
  1250. raise TypeError(
  1251. f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
  1252. )
  1253. if isinstance(scheduler, ReduceLROnPlateau):
  1254. raise ValueError(
  1255. f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
  1256. "requires additional kwargs to be specified when calling `step`, "
  1257. f"but got one at index {scheduler_idx} in the given schedulers sequence."
  1258. )
  1259. if optimizer != scheduler.optimizer:
  1260. raise ValueError(
  1261. f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
  1262. f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
  1263. f"which is different from {optimizer.__class__.__name__}."
  1264. )
  1265. self._schedulers = schedulers
  1266. self.optimizer = optimizer
  1267. self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
  1268. def step(self) -> None: # type: ignore[override]
  1269. """Perform a step."""
  1270. for scheduler in self._schedulers:
  1271. scheduler.step()
  1272. self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
  1273. @override
  1274. def state_dict(self) -> dict[str, Any]:
  1275. """Return the state of the scheduler as a :class:`dict`.
  1276. It contains an entry for every variable in ``self.__dict__`` which
  1277. is not the optimizer.
  1278. The wrapped scheduler states will also be saved.
  1279. """
  1280. state_dict = {
  1281. key: value
  1282. for key, value in self.__dict__.items()
  1283. if key not in ("optimizer", "_schedulers")
  1284. }
  1285. state_dict["_schedulers"] = [None] * len(self._schedulers)
  1286. for idx, s in enumerate(self._schedulers):
  1287. # pyrefly: ignore [unsupported-operation]
  1288. state_dict["_schedulers"][idx] = s.state_dict()
  1289. return state_dict
  1290. @override
  1291. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  1292. """Load the scheduler's state.
  1293. Args:
  1294. state_dict (dict): scheduler state. Should be an object returned
  1295. from a call to :meth:`state_dict`.
  1296. """
  1297. _schedulers = state_dict.pop("_schedulers")
  1298. self.__dict__.update(state_dict)
  1299. # Restore state_dict keys in order to prevent side effects
  1300. # https://github.com/pytorch/pytorch/issues/32756
  1301. state_dict["_schedulers"] = _schedulers
  1302. for idx, s in enumerate(_schedulers):
  1303. self._schedulers[idx].load_state_dict(s)
  1304. class ReduceLROnPlateau(LRScheduler):
  1305. """Reduce learning rate when a metric has stopped improving.
  1306. Models often benefit from reducing the learning rate by a factor
  1307. of 2-10 once learning stagnates. This scheduler reads a metrics
  1308. quantity and if no improvement is seen for a 'patience' number
  1309. of epochs, the learning rate is reduced.
  1310. Args:
  1311. optimizer (Optimizer): Wrapped optimizer.
  1312. mode (str): One of `min`, `max`. In `min` mode, lr will
  1313. be reduced when the quantity monitored has stopped
  1314. decreasing; in `max` mode it will be reduced when the
  1315. quantity monitored has stopped increasing. Default: 'min'.
  1316. factor (float): Factor by which the learning rate will be
  1317. reduced. new_lr = lr * factor. Default: 0.1.
  1318. patience (int): The number of allowed epochs with no improvement after
  1319. which the learning rate will be reduced.
  1320. For example, consider the case of having no patience (`patience = 0`).
  1321. In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
  1322. In the second epoch, if the performance is worse than the baseline,
  1323. we have what is considered an intolerable epoch.
  1324. Since the count of intolerable epochs (1) is greater than the patience level (0),
  1325. the learning rate is reduced at the end of this epoch.
  1326. From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
  1327. if the performance is worse than the baseline. If the performance improves or remains the same,
  1328. the learning rate is not adjusted.
  1329. Default: 10.
  1330. threshold (float): Threshold for measuring the new optimum,
  1331. to only focus on significant changes. Default: 1e-4.
  1332. threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
  1333. dynamic_threshold = best * ( 1 + threshold ) in 'max'
  1334. mode or best * ( 1 - threshold ) in `min` mode.
  1335. In `abs` mode, dynamic_threshold = best + threshold in
  1336. `max` mode or best - threshold in `min` mode. Default: 'rel'.
  1337. cooldown (int): Number of epochs to wait before resuming
  1338. normal operation after lr has been reduced. Default: 0.
  1339. min_lr (float or list): A scalar or a list of scalars. A
  1340. lower bound on the learning rate of all param groups
  1341. or each group respectively. Default: 0.
  1342. eps (float): Minimal decay applied to lr. If the difference
  1343. between new and old lr is smaller than eps, the update is
  1344. ignored. Default: 1e-8.
  1345. Example:
  1346. >>> # xdoctest: +SKIP
  1347. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  1348. >>> scheduler = ReduceLROnPlateau(optimizer, "min")
  1349. >>> for epoch in range(10):
  1350. >>> train(...)
  1351. >>> val_loss = validate(...)
  1352. >>> # Note that step should be called after validate()
  1353. >>> scheduler.step(val_loss)
  1354. .. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
  1355. """
  1356. def __init__(
  1357. self,
  1358. optimizer: Optimizer,
  1359. mode: Literal["min", "max"] = "min",
  1360. factor: float = 0.1,
  1361. patience: int = 10,
  1362. threshold: float = 1e-4,
  1363. threshold_mode: Literal["rel", "abs"] = "rel",
  1364. cooldown: int = 0,
  1365. min_lr: list[float] | float = 0,
  1366. eps: float = 1e-8,
  1367. ) -> None: # noqa: D107
  1368. if factor >= 1.0:
  1369. raise ValueError("Factor should be < 1.0.")
  1370. self.factor = factor
  1371. # Attach optimizer
  1372. if not isinstance(optimizer, Optimizer):
  1373. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  1374. self.optimizer = optimizer
  1375. if isinstance(min_lr, (list, tuple)):
  1376. if len(min_lr) != len(optimizer.param_groups):
  1377. raise ValueError(
  1378. f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
  1379. )
  1380. self.default_min_lr = None
  1381. self.min_lrs = list(min_lr)
  1382. else:
  1383. # pyrefly: ignore [bad-assignment]
  1384. self.default_min_lr = min_lr
  1385. self.min_lrs = [min_lr] * len(optimizer.param_groups)
  1386. self.patience = patience
  1387. self.cooldown = cooldown
  1388. self.eps = eps
  1389. self.last_epoch = 0
  1390. self._last_lr = _param_groups_val_list(self.optimizer, "lr")
  1391. self._init_is_better(
  1392. mode=mode, threshold=threshold, threshold_mode=threshold_mode
  1393. )
  1394. self._reset()
  1395. def _reset(self) -> None:
  1396. """Reset num_bad_epochs counter and cooldown counter."""
  1397. self.best = self.mode_worse
  1398. self.cooldown_counter = 0
  1399. self.num_bad_epochs = 0
  1400. def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[override]
  1401. """Perform a step."""
  1402. # convert `metrics` to float, in case it's a zero-dim Tensor
  1403. current = float(metrics)
  1404. if epoch is None:
  1405. epoch = self.last_epoch + 1
  1406. else:
  1407. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2)
  1408. self.last_epoch = epoch
  1409. if self._is_better(current, self.best):
  1410. self.best = current
  1411. self.num_bad_epochs = 0
  1412. else:
  1413. self.num_bad_epochs += 1
  1414. if self.in_cooldown:
  1415. self.cooldown_counter -= 1
  1416. self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
  1417. if self.num_bad_epochs > self.patience:
  1418. self._reduce_lr(epoch)
  1419. self.cooldown_counter = self.cooldown
  1420. self.num_bad_epochs = 0
  1421. self._last_lr = _param_groups_val_list(self.optimizer, "lr")
  1422. def _reduce_lr(self, epoch) -> None:
  1423. if len(self.optimizer.param_groups) != len(self.min_lrs):
  1424. if self.default_min_lr is None:
  1425. raise RuntimeError(
  1426. "The number of param groups in the `optimizer` "
  1427. f"({len(self.optimizer.param_groups)}) differs "
  1428. f"from when `ReduceLROnPlateau` was initialized "
  1429. f"({len(self.min_lrs)}), usually due to a new "
  1430. "param group being added to the optimizer. Please "
  1431. "modify the `min_lrs` field to match the length "
  1432. "of the `optimizer` param groups."
  1433. )
  1434. else:
  1435. # pyrefly: ignore [bad-assignment]
  1436. self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
  1437. for i, param_group in enumerate(self.optimizer.param_groups):
  1438. old_lr = float(param_group["lr"])
  1439. new_lr = max(old_lr * self.factor, self.min_lrs[i])
  1440. if old_lr - new_lr > self.eps:
  1441. _update_param_group_val(param_group, "lr", new_lr)
  1442. @property
  1443. def in_cooldown(self): # noqa: D102
  1444. return self.cooldown_counter > 0
  1445. def _is_better(self, a, best): # noqa: D102
  1446. if self.mode == "min" and self.threshold_mode == "rel":
  1447. rel_epsilon = 1.0 - self.threshold
  1448. return a < best * rel_epsilon
  1449. elif self.mode == "min" and self.threshold_mode == "abs":
  1450. return a < best - self.threshold
  1451. elif self.mode == "max" and self.threshold_mode == "rel":
  1452. rel_epsilon = self.threshold + 1.0
  1453. return a > best * rel_epsilon
  1454. else: # mode == 'max' and epsilon_mode == 'abs':
  1455. return a > best + self.threshold
  1456. def _init_is_better(self, mode, threshold, threshold_mode) -> None:
  1457. if mode not in {"min", "max"}:
  1458. raise ValueError("mode " + mode + " is unknown!")
  1459. if threshold_mode not in {"rel", "abs"}:
  1460. raise ValueError("threshold mode " + threshold_mode + " is unknown!")
  1461. # the worse value for the chosen mode
  1462. if mode == "min":
  1463. self.mode_worse = inf
  1464. else: # mode == 'max':
  1465. self.mode_worse = -inf
  1466. self.mode = mode
  1467. self.threshold = threshold
  1468. self.threshold_mode = threshold_mode
  1469. @override
  1470. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  1471. """Load the scheduler's state."""
  1472. self.__dict__.update(state_dict)
  1473. self._init_is_better(
  1474. mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
  1475. )
  1476. class CyclicLR(LRScheduler):
  1477. r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).
  1478. The policy cycles the learning rate between two boundaries with a constant frequency,
  1479. as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_.
  1480. The distance between the two boundaries can be scaled on a per-iteration
  1481. or per-cycle basis.
  1482. Cyclical learning rate policy changes the learning rate after every batch.
  1483. `step` should be called after a batch has been used for training.
  1484. This class has three built-in policies, as put forth in the paper:
  1485. * "triangular": A basic triangular cycle without amplitude scaling.
  1486. * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
  1487. * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
  1488. at each cycle iteration.
  1489. This implementation was adapted from the github repo: `bckenstler/CLR`_
  1490. Args:
  1491. optimizer (Optimizer): Wrapped optimizer.
  1492. base_lr (float or list): Initial learning rate which is the
  1493. lower boundary in the cycle for each parameter group.
  1494. max_lr (float or list): Upper learning rate boundaries in the cycle
  1495. for each parameter group. Functionally,
  1496. it defines the cycle amplitude (max_lr - base_lr).
  1497. The lr at any cycle is the sum of base_lr
  1498. and some scaling of the amplitude; therefore
  1499. max_lr may not actually be reached depending on
  1500. scaling function.
  1501. step_size_up (int): Number of training iterations in the
  1502. increasing half of a cycle. Default: 2000
  1503. step_size_down (int): Number of training iterations in the
  1504. decreasing half of a cycle. If step_size_down is None,
  1505. it is set to step_size_up. Default: None
  1506. mode (str): One of {triangular, triangular2, exp_range}.
  1507. Values correspond to policies detailed above.
  1508. If scale_fn is not None, this argument is ignored.
  1509. Default: 'triangular'
  1510. gamma (float): Constant in 'exp_range' scaling function:
  1511. gamma**(cycle iterations)
  1512. Default: 1.0
  1513. scale_fn (function): Custom scaling policy defined by a single
  1514. argument lambda function, where
  1515. 0 <= scale_fn(x) <= 1 for all x >= 0.
  1516. If specified, then 'mode' is ignored.
  1517. Default: None
  1518. scale_mode (str): {'cycle', 'iterations'}.
  1519. Defines whether scale_fn is evaluated on
  1520. cycle number or cycle iterations (training
  1521. iterations since start of cycle).
  1522. Default: 'cycle'
  1523. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  1524. to learning rate between 'base_momentum' and 'max_momentum'.
  1525. Default: True
  1526. base_momentum (float or list): Lower momentum boundaries in the cycle
  1527. for each parameter group. Note that momentum is cycled inversely
  1528. to learning rate; at the peak of a cycle, momentum is
  1529. 'base_momentum' and learning rate is 'max_lr'.
  1530. Default: 0.8
  1531. max_momentum (float or list): Upper momentum boundaries in the cycle
  1532. for each parameter group. Functionally,
  1533. it defines the cycle amplitude (max_momentum - base_momentum).
  1534. The momentum at any cycle is the difference of max_momentum
  1535. and some scaling of the amplitude; therefore
  1536. base_momentum may not actually be reached depending on
  1537. scaling function. Note that momentum is cycled inversely
  1538. to learning rate; at the start of a cycle, momentum is 'max_momentum'
  1539. and learning rate is 'base_lr'
  1540. Default: 0.9
  1541. last_epoch (int): The index of the last batch. This parameter is used when
  1542. resuming a training job. Since `step()` should be invoked after each
  1543. batch instead of after each epoch, this number represents the total
  1544. number of *batches* computed, not the total number of epochs computed.
  1545. When last_epoch=-1, the schedule is started from the beginning.
  1546. Default: -1
  1547. Example:
  1548. >>> # xdoctest: +SKIP
  1549. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  1550. >>> scheduler = torch.optim.lr_scheduler.CyclicLR(
  1551. ... optimizer,
  1552. ... base_lr=0.01,
  1553. ... max_lr=0.1,
  1554. ... step_size_up=10,
  1555. ... )
  1556. >>> data_loader = torch.utils.data.DataLoader(...)
  1557. >>> for epoch in range(10):
  1558. >>> for batch in data_loader:
  1559. >>> train_batch(...)
  1560. >>> scheduler.step()
  1561. .. image:: ../scripts/lr_scheduler_images/CyclicLR.png
  1562. .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
  1563. .. _bckenstler/CLR: https://github.com/bckenstler/CLR
  1564. """
  1565. def __init__(
  1566. self,
  1567. optimizer: Optimizer,
  1568. base_lr: float | list[float],
  1569. max_lr: float | list[float],
  1570. step_size_up: int = 2000,
  1571. step_size_down: int | None = None,
  1572. mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
  1573. gamma: float = 1.0,
  1574. scale_fn: Callable[[float], float] | None = None,
  1575. scale_mode: Literal["cycle", "iterations"] = "cycle",
  1576. cycle_momentum: bool = True,
  1577. base_momentum: float = 0.8,
  1578. max_momentum: float = 0.9,
  1579. last_epoch: int = -1,
  1580. ) -> None: # noqa: D107
  1581. # Attach optimizer
  1582. if not isinstance(optimizer, Optimizer):
  1583. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  1584. self.optimizer = optimizer
  1585. base_lrs = _format_param("base_lr", optimizer, base_lr)
  1586. if last_epoch == -1:
  1587. for lr, group in zip(base_lrs, optimizer.param_groups, strict=True):
  1588. _update_param_group_val(group, "lr", lr)
  1589. self.max_lrs = _format_param("max_lr", optimizer, max_lr)
  1590. # pyrefly: ignore [bad-assignment]
  1591. step_size_up = float(step_size_up)
  1592. step_size_down = (
  1593. # pyrefly: ignore [bad-assignment]
  1594. float(step_size_down) if step_size_down is not None else step_size_up
  1595. )
  1596. # pyrefly: ignore [unsupported-operation]
  1597. self.total_size = step_size_up + step_size_down
  1598. self.step_ratio = step_size_up / self.total_size
  1599. if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
  1600. raise ValueError("mode is invalid and scale_fn is None")
  1601. self.mode = mode
  1602. self.gamma = gamma
  1603. self._scale_fn_ref: Callable[[float], float]
  1604. self._scale_fn_custom = scale_fn
  1605. self.scale_mode = scale_mode
  1606. self._init_scale_fn()
  1607. self.cycle_momentum = cycle_momentum
  1608. if cycle_momentum:
  1609. if (
  1610. "momentum" not in optimizer.defaults
  1611. and "betas" not in optimizer.defaults
  1612. ):
  1613. raise ValueError(
  1614. "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
  1615. )
  1616. self.use_beta1 = "betas" in self.optimizer.defaults
  1617. self.base_momentums = _format_param(
  1618. "base_momentum", optimizer, base_momentum
  1619. )
  1620. self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
  1621. if last_epoch == -1:
  1622. for m_momentum, b_momentum, group in zip(
  1623. self.max_momentums,
  1624. self.base_momentums,
  1625. optimizer.param_groups,
  1626. strict=True,
  1627. ):
  1628. if self.use_beta1:
  1629. group["betas"] = (m_momentum, *group["betas"][1:])
  1630. else:
  1631. group["momentum"] = m_momentum
  1632. group["max_momentum"] = m_momentum
  1633. group["base_momentum"] = b_momentum
  1634. super().__init__(optimizer, last_epoch)
  1635. self.base_lrs = base_lrs
  1636. def _init_scale_fn(self) -> None:
  1637. if self._scale_fn_custom is not None:
  1638. return
  1639. if self.mode == "triangular":
  1640. self._scale_fn_ref = self._triangular_scale_fn
  1641. self.scale_mode = "cycle"
  1642. elif self.mode == "triangular2":
  1643. self._scale_fn_ref = self._triangular2_scale_fn
  1644. self.scale_mode = "cycle"
  1645. elif self.mode == "exp_range":
  1646. self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
  1647. self.scale_mode = "iterations"
  1648. def scale_fn(self, x) -> float:
  1649. """Get the scaling policy."""
  1650. if self._scale_fn_custom is not None:
  1651. return self._scale_fn_custom(x)
  1652. else:
  1653. return self._scale_fn_ref(x) # static method
  1654. @staticmethod
  1655. def _triangular_scale_fn(x: float) -> float:
  1656. return 1.0
  1657. @staticmethod
  1658. def _triangular2_scale_fn(x: float) -> float:
  1659. return 1 / (2.0 ** (x - 1))
  1660. @staticmethod
  1661. def _exp_range_scale_fn(gamma: float, x: float) -> float:
  1662. return gamma**x
  1663. @override
  1664. def get_lr(self) -> list[float | Tensor]:
  1665. r"""Compute the next learning rate for each of the optimizer's
  1666. :attr:`~torch.optim.Optimizer.param_groups`.
  1667. Advances each ``group["lr"]`` in the optimizer's
  1668. :attr:`~torch.optim.Optimizer.param_groups` along a cycle between the
  1669. group's ``base_lr`` and ``max_lr`` using :meth:`scale_fn`.
  1670. Returns:
  1671. list[float | Tensor]: A :class:`list` of learning rates for each of
  1672. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  1673. same types as their current ``group["lr"]``\s.
  1674. .. note::
  1675. If you're trying to inspect the most recent learning rate, use
  1676. :meth:`get_last_lr()` instead.
  1677. .. note::
  1678. The returned :class:`~torch.Tensor`\s are copies, and never alias
  1679. the optimizer's ``group["lr"]``\s.
  1680. .. note::
  1681. This method treats :attr:`last_epoch` as the index of the previous
  1682. batch.
  1683. .. note::
  1684. When :attr:`cycle_momentum` is ``True``, this method has a side
  1685. effect of updating the optimizer's momentum.
  1686. """
  1687. _warn_get_lr_called_within_step(self)
  1688. cycle = math.floor(1 + self.last_epoch / self.total_size)
  1689. x = 1.0 + self.last_epoch / self.total_size - cycle
  1690. if x <= self.step_ratio:
  1691. scale_factor = x / self.step_ratio
  1692. else:
  1693. scale_factor = (x - 1) / (self.step_ratio - 1)
  1694. lrs = []
  1695. for base_lr, max_lr in zip(self.base_lrs, self.max_lrs, strict=True):
  1696. base_height = (max_lr - base_lr) * scale_factor
  1697. if self.scale_mode == "cycle":
  1698. lr = base_lr + base_height * self.scale_fn(cycle)
  1699. else:
  1700. lr = base_lr + base_height * self.scale_fn(self.last_epoch)
  1701. lrs.append(lr)
  1702. if self.cycle_momentum:
  1703. momentums = []
  1704. for base_momentum, max_momentum in zip(
  1705. self.base_momentums, self.max_momentums, strict=True
  1706. ):
  1707. base_height = (max_momentum - base_momentum) * scale_factor
  1708. if self.scale_mode == "cycle":
  1709. momentum = max_momentum - base_height * self.scale_fn(cycle)
  1710. else:
  1711. momentum = max_momentum - base_height * self.scale_fn(
  1712. self.last_epoch
  1713. )
  1714. momentums.append(momentum)
  1715. for param_group, momentum in zip(
  1716. self.optimizer.param_groups, momentums, strict=True
  1717. ):
  1718. if self.use_beta1:
  1719. param_group["betas"] = (momentum, *param_group["betas"][1:])
  1720. else:
  1721. param_group["momentum"] = momentum
  1722. return lrs
  1723. @override
  1724. def state_dict(self) -> dict[str, Any]: # noqa: D102
  1725. """Return the state of the scheduler as a :class:`dict`.
  1726. It contains an entry for every variable in ``self.__dict__`` which
  1727. is not the optimizer.
  1728. The learning rate lambda functions will only be saved if they are callable objects
  1729. and not if they are functions or lambdas.
  1730. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  1731. """
  1732. state = super().state_dict()
  1733. # We are dropping the `_scale_fn_ref` attribute because it is a
  1734. # `weakref.WeakMethod` and can't be pickled.
  1735. state.pop("_scale_fn_ref", None)
  1736. fn = state.pop("_scale_fn_custom")
  1737. state["_scale_fn_custom"] = None
  1738. if fn is not None and not isinstance(fn, types.FunctionType):
  1739. # The _scale_fn_custom will only be saved if it is a callable object
  1740. # and not if it is a function or lambda.
  1741. state["_scale_fn_custom"] = fn.__dict__.copy()
  1742. return state
  1743. @override
  1744. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  1745. """Load the scheduler's state."""
  1746. fn = state_dict.pop("_scale_fn_custom")
  1747. super().load_state_dict(state_dict)
  1748. if fn is not None:
  1749. self._scale_fn_custom.__dict__.update(fn)
  1750. self._init_scale_fn()
  1751. class CosineAnnealingWarmRestarts(LRScheduler):
  1752. r"""Set the learning rate of each parameter group using a cosine annealing schedule.
  1753. The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
  1754. is the number of epochs since the last restart and :math:`T_{i}` is the number
  1755. of epochs between two warm restarts in SGDR:
  1756. .. math::
  1757. \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
  1758. \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
  1759. When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
  1760. When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
  1761. It has been proposed in
  1762. `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
  1763. Args:
  1764. optimizer (Optimizer): Wrapped optimizer.
  1765. T_0 (int): Number of iterations until the first restart.
  1766. T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
  1767. eta_min (float, optional): Minimum learning rate. Default: 0.
  1768. last_epoch (int, optional): The index of the last epoch. Default: -1.
  1769. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
  1770. https://arxiv.org/abs/1608.03983
  1771. Example:
  1772. >>> # xdoctest: +SKIP
  1773. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
  1774. >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
  1775. ... optimizer, T_0=20
  1776. ... )
  1777. >>> for epoch in range(100):
  1778. >>> train(...)
  1779. >>> validate(...)
  1780. >>> scheduler.step()
  1781. .. image:: ../scripts/lr_scheduler_images/CosineAnnealingWarmRestarts.png
  1782. """
  1783. def __init__(
  1784. self,
  1785. optimizer: Optimizer,
  1786. T_0: int,
  1787. T_mult: int = 1,
  1788. eta_min: float = 0.0,
  1789. last_epoch: int = -1,
  1790. ) -> None: # noqa: D107
  1791. if T_0 <= 0 or not isinstance(T_0, int):
  1792. raise ValueError(f"Expected positive integer T_0, but got {T_0}")
  1793. if T_mult < 1 or not isinstance(T_mult, int):
  1794. raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
  1795. if not isinstance(eta_min, (float, int)):
  1796. raise ValueError(
  1797. f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}"
  1798. )
  1799. self.T_0 = T_0
  1800. self.T_i = T_0
  1801. self.T_mult = T_mult
  1802. self.eta_min = eta_min
  1803. self.T_cur = last_epoch
  1804. super().__init__(optimizer, last_epoch)
  1805. @override
  1806. def get_lr(self) -> list[float | Tensor]:
  1807. r"""Compute the next learning rate for each of the optimizer's
  1808. :attr:`~torch.optim.Optimizer.param_groups`.
  1809. Computes learning rates for the optimizer's
  1810. :attr:`~torch.optim.Optimizer.param_groups` following:
  1811. .. math::
  1812. \texttt{eta\_min} + \frac{1}{2}(\texttt{base\_lr} -
  1813. \texttt{eta\_min})\left(1 + \cos\left(\pi \cdot
  1814. \frac{\texttt{T\_cur}}{\texttt{T\_i}}\right)\right)
  1815. Where :attr:`T_cur` is the number of epochs since the last restart and
  1816. :attr:`T_i` is the number of epochs between two restarts. Both
  1817. :attr:`T_cur` and :attr:`T_i` are updated in :meth:`step`, and
  1818. :attr:`T_i` becomes :attr:`T_mult` times larger after each restart.
  1819. Returns:
  1820. list[float | Tensor]: A :class:`list` of learning rates for each of
  1821. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  1822. same types as their current ``group["lr"]``\s.
  1823. .. note::
  1824. If you're trying to inspect the most recent learning rate, use
  1825. :meth:`get_last_lr()` instead.
  1826. .. note::
  1827. The returned :class:`~torch.Tensor`\s are copies, and never alias
  1828. the optimizer's ``group["lr"]``\s.
  1829. """
  1830. _warn_get_lr_called_within_step(self)
  1831. return [
  1832. self.eta_min
  1833. + (base_lr - self.eta_min)
  1834. * (1 + math.cos(math.pi * self.T_cur / self.T_i))
  1835. / 2
  1836. for base_lr in self.base_lrs
  1837. ]
  1838. @override
  1839. def step(self, epoch=None) -> None:
  1840. """Step could be called after every batch update.
  1841. Example:
  1842. >>> # xdoctest: +SKIP("Undefined vars")
  1843. >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
  1844. >>> iters = len(dataloader)
  1845. >>> for epoch in range(20):
  1846. >>> for i, sample in enumerate(dataloader):
  1847. >>> inputs, labels = sample['inputs'], sample['labels']
  1848. >>> optimizer.zero_grad()
  1849. >>> outputs = net(inputs)
  1850. >>> loss = criterion(outputs, labels)
  1851. >>> loss.backward()
  1852. >>> optimizer.step()
  1853. >>> scheduler.step(epoch + i / iters)
  1854. This function can be called in an interleaved way.
  1855. Example:
  1856. >>> # xdoctest: +SKIP("Undefined vars")
  1857. >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
  1858. >>> for epoch in range(20):
  1859. >>> scheduler.step()
  1860. >>> scheduler.step(26)
  1861. >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
  1862. """
  1863. if epoch is None and self.last_epoch < 0:
  1864. epoch = 0
  1865. if epoch is None:
  1866. epoch = self.last_epoch + 1
  1867. self.T_cur = self.T_cur + 1
  1868. if self.T_cur >= self.T_i:
  1869. self.T_cur = self.T_cur % self.T_i
  1870. self.T_i = self.T_i * self.T_mult
  1871. else:
  1872. if epoch < 0:
  1873. raise ValueError(f"Expected non-negative epoch, but got {epoch}")
  1874. if epoch >= self.T_0:
  1875. if self.T_mult == 1:
  1876. self.T_cur = epoch % self.T_0
  1877. else:
  1878. n = int(
  1879. math.log(
  1880. (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
  1881. )
  1882. )
  1883. self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
  1884. self.T_mult - 1
  1885. )
  1886. self.T_i = self.T_0 * self.T_mult ** (n)
  1887. else:
  1888. self.T_i = self.T_0
  1889. self.T_cur = epoch
  1890. self.last_epoch = math.floor(epoch)
  1891. with _enable_get_lr_call(self):
  1892. for param_group, lr in zip(
  1893. self.optimizer.param_groups, self.get_lr(), strict=True
  1894. ):
  1895. _update_param_group_val(param_group, "lr", lr)
  1896. self._last_lr = _param_groups_val_list(self.optimizer, "lr")
  1897. class _SchedulePhase(TypedDict):
  1898. end_step: float
  1899. start_lr: str
  1900. end_lr: str
  1901. start_momentum: str
  1902. end_momentum: str
class OneCycleLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy.

    The 1cycle policy anneals the learning rate from an initial learning rate to some maximum
    learning rate and then from that maximum learning rate to some minimum learning rate much
    lower than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The fraction of the cycle (in number of steps) spent
            increasing the learning rate. Must be a float between 0 and 1.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'initial_lr'.
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
            Default: False
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> # xdoctest: +SKIP
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
        ...     optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
        ... )
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         optimizer.step()
        >>>         scheduler.step()

    .. image:: ../scripts/lr_scheduler_images/OneCycleLR.png

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
  1995. def __init__(
  1996. self,
  1997. optimizer: Optimizer,
  1998. max_lr: float | list[float],
  1999. total_steps: int | None = None,
  2000. epochs: int | None = None,
  2001. steps_per_epoch: int | None = None,
  2002. pct_start: float = 0.3,
  2003. anneal_strategy: Literal["cos", "linear"] = "cos",
  2004. cycle_momentum: bool = True,
  2005. base_momentum: float | list[float] = 0.85,
  2006. max_momentum: float | list[float] = 0.95,
  2007. div_factor: float = 25.0,
  2008. final_div_factor: float = 1e4,
  2009. three_phase: bool = False,
  2010. last_epoch: int = -1,
  2011. ) -> None: # noqa: D107
  2012. # Validate optimizer
  2013. if not isinstance(optimizer, Optimizer):
  2014. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  2015. self.optimizer = optimizer
  2016. # Validate total_steps
  2017. if total_steps is not None:
  2018. if total_steps <= 0 or not isinstance(total_steps, int):
  2019. raise ValueError(
  2020. f"Expected positive integer total_steps, but got {total_steps}"
  2021. )
  2022. self.total_steps = total_steps
  2023. elif epochs is not None and steps_per_epoch is not None:
  2024. if not isinstance(epochs, int) or epochs <= 0:
  2025. raise ValueError(f"Expected positive integer epochs, but got {epochs}")
  2026. if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
  2027. raise ValueError(
  2028. f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}"
  2029. )
  2030. self.total_steps = epochs * steps_per_epoch
  2031. else:
  2032. raise ValueError(
  2033. "You must define either total_steps OR (epochs AND steps_per_epoch)"
  2034. )
  2035. self._schedule_phases: list[_SchedulePhase]
  2036. if three_phase:
  2037. self._schedule_phases = [
  2038. {
  2039. "end_step": float(pct_start * self.total_steps) - 1,
  2040. "start_lr": "initial_lr",
  2041. "end_lr": "max_lr",
  2042. "start_momentum": "max_momentum",
  2043. "end_momentum": "base_momentum",
  2044. },
  2045. {
  2046. "end_step": float(2 * pct_start * self.total_steps) - 2,
  2047. "start_lr": "max_lr",
  2048. "end_lr": "initial_lr",
  2049. "start_momentum": "base_momentum",
  2050. "end_momentum": "max_momentum",
  2051. },
  2052. {
  2053. "end_step": self.total_steps - 1,
  2054. "start_lr": "initial_lr",
  2055. "end_lr": "min_lr",
  2056. "start_momentum": "max_momentum",
  2057. "end_momentum": "max_momentum",
  2058. },
  2059. ]
  2060. else:
  2061. self._schedule_phases = [
  2062. {
  2063. "end_step": float(pct_start * self.total_steps) - 1,
  2064. "start_lr": "initial_lr",
  2065. "end_lr": "max_lr",
  2066. "start_momentum": "max_momentum",
  2067. "end_momentum": "base_momentum",
  2068. },
  2069. {
  2070. "end_step": self.total_steps - 1,
  2071. "start_lr": "max_lr",
  2072. "end_lr": "min_lr",
  2073. "start_momentum": "base_momentum",
  2074. "end_momentum": "max_momentum",
  2075. },
  2076. ]
  2077. # Validate pct_start
  2078. if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
  2079. raise ValueError(
  2080. f"Expected float between 0 and 1 pct_start, but got {pct_start}"
  2081. )
  2082. # Validate anneal_strategy
  2083. if anneal_strategy not in ["cos", "linear"]:
  2084. raise ValueError(
  2085. f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}"
  2086. )
  2087. else:
  2088. self._anneal_func_type = anneal_strategy
  2089. # Initialize learning rate variables
  2090. max_lrs = _format_param("max_lr", self.optimizer, max_lr)
  2091. if last_epoch == -1:
  2092. for idx, group in enumerate(self.optimizer.param_groups):
  2093. group["initial_lr"] = max_lrs[idx] / div_factor
  2094. group["max_lr"] = max_lrs[idx]
  2095. group["min_lr"] = group["initial_lr"] / final_div_factor
  2096. # Initialize momentum variables
  2097. self.cycle_momentum = cycle_momentum
  2098. if self.cycle_momentum:
  2099. if (
  2100. "momentum" not in self.optimizer.defaults
  2101. and "betas" not in self.optimizer.defaults
  2102. ):
  2103. raise ValueError(
  2104. "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
  2105. )
  2106. self.use_beta1 = "betas" in self.optimizer.defaults
  2107. max_momentums = _format_param("max_momentum", optimizer, max_momentum)
  2108. base_momentums = _format_param("base_momentum", optimizer, base_momentum)
  2109. if last_epoch == -1:
  2110. for m_momentum, b_momentum, group in zip(
  2111. max_momentums, base_momentums, optimizer.param_groups, strict=True
  2112. ):
  2113. if self.use_beta1:
  2114. group["betas"] = (m_momentum, *group["betas"][1:])
  2115. else:
  2116. group["momentum"] = m_momentum
  2117. group["max_momentum"] = m_momentum
  2118. group["base_momentum"] = b_momentum
  2119. super().__init__(optimizer, last_epoch)
  2120. def _anneal_func(self, *args, **kwargs):
  2121. if hasattr(self, "_anneal_func_type"):
  2122. if self._anneal_func_type == "cos":
  2123. return self._annealing_cos(*args, **kwargs)
  2124. elif self._anneal_func_type == "linear":
  2125. return self._annealing_linear(*args, **kwargs)
  2126. else:
  2127. raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}")
  2128. else:
  2129. # For BC
  2130. return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined]
  2131. @staticmethod
  2132. def _annealing_cos(start, end, pct):
  2133. """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
  2134. cos_out = math.cos(math.pi * pct) + 1
  2135. return end + (start - end) / 2.0 * cos_out
  2136. @staticmethod
  2137. def _annealing_linear(start, end, pct):
  2138. """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
  2139. return (end - start) * pct + start
  2140. @override
  2141. def get_lr(self) -> list[float | Tensor]:
  2142. r"""Compute the next learning rate for each of the optimizer's
  2143. :attr:`~torch.optim.Optimizer.param_groups`.
  2144. Finds the appropriate :attr:`_schedule_phases` entry for the current
  2145. step and interpolates between its ``start_lr`` and ``end_lr`` using
  2146. :meth:`_anneal_func`.
  2147. Returns:
  2148. list[float | Tensor]: A :class:`list` of learning rates for each of
  2149. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  2150. same types as their current ``group["lr"]``\s.
  2151. .. note::
  2152. If you're trying to inspect the most recent learning rate, use
  2153. :meth:`get_last_lr()` instead.
  2154. .. note::
  2155. The returned :class:`~torch.Tensor`\s are copies, and never alias
  2156. the optimizer's ``group["lr"]``\s.
  2157. .. note::
  2158. When :attr:`cycle_momentum` is ``True``, this method has a side
  2159. effect of updating the optimizer's momentum.
  2160. """
  2161. _warn_get_lr_called_within_step(self)
  2162. lrs = []
  2163. step_num = self.last_epoch
  2164. if step_num > self.total_steps:
  2165. raise ValueError(
  2166. f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"
  2167. )
  2168. for group in self.optimizer.param_groups:
  2169. start_step = 0.0
  2170. for i, phase in enumerate(self._schedule_phases):
  2171. end_step = phase["end_step"]
  2172. if step_num <= end_step or i == len(self._schedule_phases) - 1:
  2173. pct = (step_num - start_step) / (end_step - start_step)
  2174. computed_lr = self._anneal_func(
  2175. group[phase["start_lr"]], group[phase["end_lr"]], pct
  2176. )
  2177. if self.cycle_momentum:
  2178. computed_momentum = self._anneal_func(
  2179. group[phase["start_momentum"]],
  2180. group[phase["end_momentum"]],
  2181. pct,
  2182. )
  2183. break
  2184. start_step = phase["end_step"]
  2185. lrs.append(computed_lr) # type: ignore[possibly-undefined]
  2186. if self.cycle_momentum:
  2187. if self.use_beta1:
  2188. group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined]
  2189. else:
  2190. group["momentum"] = computed_momentum # type: ignore[possibly-undefined]
  2191. return lrs