optimization.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch optimization for BERT model."""
  15. from __future__ import annotations
  16. import math
  17. import warnings
  18. from functools import partial
  19. from typing import Any
  20. import torch
  21. from torch.optim import Optimizer
  22. from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
  23. from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
  24. from .trainer_utils import SchedulerType
  25. from .utils import logging
  26. logger = logging.get_logger(__name__)
  27. def _get_constant_lambda(_=None):
  28. return 1
  29. def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
  30. """
  31. Create a schedule with a constant learning rate, using the learning rate set in optimizer.
  32. Args:
  33. optimizer ([`~torch.optim.Optimizer`]):
  34. The optimizer for which to schedule the learning rate.
  35. last_epoch (`int`, *optional*, defaults to -1):
  36. The index of the last epoch when resuming training.
  37. Return:
  38. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  39. """
  40. return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
  41. def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
  42. """
  43. Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
  44. Args:
  45. optimizer ([`~torch.optim.Optimizer`]):
  46. The optimizer for which to schedule the learning rate.
  47. kwargs (`dict`, *optional*):
  48. Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
  49. for possible parameters.
  50. Return:
  51. `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
  52. """
  53. return ReduceLROnPlateau(optimizer, **kwargs)
  54. def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
  55. if current_step < num_warmup_steps:
  56. return float(current_step) / float(max(1.0, num_warmup_steps))
  57. return 1.0
  58. def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
  59. """
  60. Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
  61. increases linearly between 0 and the initial lr set in the optimizer.
  62. Args:
  63. optimizer ([`~torch.optim.Optimizer`]):
  64. The optimizer for which to schedule the learning rate.
  65. num_warmup_steps (`int`):
  66. The number of steps for the warmup phase.
  67. last_epoch (`int`, *optional*, defaults to -1):
  68. The index of the last epoch when resuming training.
  69. Return:
  70. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  71. """
  72. lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
  73. return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
  74. def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
  75. if current_step < num_warmup_steps:
  76. return float(current_step) / float(max(1, num_warmup_steps))
  77. return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
  78. def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
  79. """
  80. Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
  81. a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
  82. Args:
  83. optimizer ([`~torch.optim.Optimizer`]):
  84. The optimizer for which to schedule the learning rate.
  85. num_warmup_steps (`int`):
  86. The number of steps for the warmup phase.
  87. num_training_steps (`int`):
  88. The total number of training steps.
  89. last_epoch (`int`, *optional*, defaults to -1):
  90. The index of the last epoch when resuming training.
  91. Return:
  92. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  93. """
  94. lr_lambda = partial(
  95. _get_linear_schedule_with_warmup_lr_lambda,
  96. num_warmup_steps=num_warmup_steps,
  97. num_training_steps=num_training_steps,
  98. )
  99. return LambdaLR(optimizer, lr_lambda, last_epoch)
  100. def _get_cosine_schedule_with_warmup_lr_lambda(
  101. current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
  102. ):
  103. if current_step < num_warmup_steps:
  104. return float(current_step) / float(max(1, num_warmup_steps))
  105. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  106. return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
  107. def get_cosine_schedule_with_warmup(
  108. optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
  109. ):
  110. """
  111. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  112. initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
  113. initial lr set in the optimizer.
  114. Args:
  115. optimizer ([`~torch.optim.Optimizer`]):
  116. The optimizer for which to schedule the learning rate.
  117. num_warmup_steps (`int`):
  118. The number of steps for the warmup phase.
  119. num_training_steps (`int`):
  120. The total number of training steps.
  121. num_cycles (`float`, *optional*, defaults to 0.5):
  122. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  123. following a half-cosine).
  124. last_epoch (`int`, *optional*, defaults to -1):
  125. The index of the last epoch when resuming training.
  126. Return:
  127. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  128. """
  129. lr_lambda = partial(
  130. _get_cosine_schedule_with_warmup_lr_lambda,
  131. num_warmup_steps=num_warmup_steps,
  132. num_training_steps=num_training_steps,
  133. num_cycles=num_cycles,
  134. )
  135. return LambdaLR(optimizer, lr_lambda, last_epoch)
  136. def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
  137. current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
  138. ):
  139. if current_step < num_warmup_steps:
  140. return float(current_step) / float(max(1, num_warmup_steps))
  141. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  142. if progress >= 1.0:
  143. return 0.0
  144. return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
  145. def get_cosine_with_hard_restarts_schedule_with_warmup(
  146. optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
  147. ):
  148. """
  149. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  150. initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
  151. linearly between 0 and the initial lr set in the optimizer.
  152. Args:
  153. optimizer ([`~torch.optim.Optimizer`]):
  154. The optimizer for which to schedule the learning rate.
  155. num_warmup_steps (`int`):
  156. The number of steps for the warmup phase.
  157. num_training_steps (`int`):
  158. The total number of training steps.
  159. num_cycles (`int`, *optional*, defaults to 1):
  160. The number of hard restarts to use.
  161. last_epoch (`int`, *optional*, defaults to -1):
  162. The index of the last epoch when resuming training.
  163. Return:
  164. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  165. """
  166. lr_lambda = partial(
  167. _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
  168. num_warmup_steps=num_warmup_steps,
  169. num_training_steps=num_training_steps,
  170. num_cycles=num_cycles,
  171. )
  172. return LambdaLR(optimizer, lr_lambda, last_epoch)
  173. def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
  174. current_step: int,
  175. *,
  176. num_warmup_steps: int,
  177. num_training_steps: int,
  178. lr_end: float,
  179. power: float,
  180. lr_init: int,
  181. ):
  182. if current_step < num_warmup_steps:
  183. return float(current_step) / float(max(1, num_warmup_steps))
  184. elif current_step > num_training_steps:
  185. return lr_end / lr_init # as LambdaLR multiplies by lr_init
  186. else:
  187. lr_range = lr_init - lr_end
  188. decay_steps = num_training_steps - num_warmup_steps
  189. pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
  190. decay = lr_range * pct_remaining**power + lr_end
  191. return decay / lr_init # as LambdaLR multiplies by lr_init
  192. def get_polynomial_decay_schedule_with_warmup(
  193. optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
  194. ):
  195. """
  196. Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
  197. optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
  198. initial lr set in the optimizer.
  199. Args:
  200. optimizer ([`~torch.optim.Optimizer`]):
  201. The optimizer for which to schedule the learning rate.
  202. num_warmup_steps (`int`):
  203. The number of steps for the warmup phase.
  204. num_training_steps (`int`):
  205. The total number of training steps.
  206. lr_end (`float`, *optional*, defaults to 1e-7):
  207. The end LR.
  208. power (`float`, *optional*, defaults to 1.0):
  209. Power factor.
  210. last_epoch (`int`, *optional*, defaults to -1):
  211. The index of the last epoch when resuming training.
  212. Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
  213. implementation at
  214. https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
  215. Return:
  216. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  217. """
  218. lr_init = optimizer.defaults["lr"]
  219. if not (lr_init > lr_end):
  220. raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
  221. lr_lambda = partial(
  222. _get_polynomial_decay_schedule_with_warmup_lr_lambda,
  223. num_warmup_steps=num_warmup_steps,
  224. num_training_steps=num_training_steps,
  225. lr_end=lr_end,
  226. power=power,
  227. lr_init=lr_init,
  228. )
  229. return LambdaLR(optimizer, lr_lambda, last_epoch)
  230. def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int | None = None):
  231. if current_step < num_warmup_steps:
  232. return float(current_step) / float(max(1, num_warmup_steps))
  233. shift = timescale - num_warmup_steps
  234. decay = 1.0 / math.sqrt((current_step + shift) / timescale)
  235. return decay
  236. def get_inverse_sqrt_schedule(
  237. optimizer: Optimizer, num_warmup_steps: int, timescale: int | None = None, last_epoch: int = -1
  238. ):
  239. """
  240. Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
  241. warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
  242. Args:
  243. optimizer ([`~torch.optim.Optimizer`]):
  244. The optimizer for which to schedule the learning rate.
  245. num_warmup_steps (`int`):
  246. The number of steps for the warmup phase.
  247. timescale (`int`, *optional*, defaults to `num_warmup_steps`):
  248. Time scale.
  249. last_epoch (`int`, *optional*, defaults to -1):
  250. The index of the last epoch when resuming training.
  251. Return:
  252. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  253. """
  254. # Note: this implementation is adapted from
  255. # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
  256. if timescale is None:
  257. timescale = num_warmup_steps or 10_000
  258. lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
  259. return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
  260. def _get_cosine_schedule_with_warmup_lr_lambda(
  261. current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
  262. ):
  263. if current_step < num_warmup_steps:
  264. return float(current_step) / float(max(1, num_warmup_steps))
  265. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  266. factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
  267. factor = factor * (1 - min_lr_rate) + min_lr_rate
  268. return max(0, factor)
  269. def get_cosine_with_min_lr_schedule_with_warmup(
  270. optimizer: Optimizer,
  271. num_warmup_steps: int,
  272. num_training_steps: int,
  273. num_cycles: float = 0.5,
  274. last_epoch: int = -1,
  275. min_lr: float | None = None,
  276. min_lr_rate: float | None = None,
  277. ):
  278. """
  279. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  280. initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
  281. initial lr set in the optimizer.
  282. Args:
  283. optimizer ([`~torch.optim.Optimizer`]):
  284. The optimizer for which to schedule the learning rate.
  285. num_warmup_steps (`int`):
  286. The number of steps for the warmup phase.
  287. num_training_steps (`int`):
  288. The total number of training steps.
  289. num_cycles (`float`, *optional*, defaults to 0.5):
  290. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  291. following a half-cosine).
  292. last_epoch (`int`, *optional*, defaults to -1):
  293. The index of the last epoch when resuming training.
  294. min_lr (`float`, *optional*):
  295. The minimum learning rate to reach after the cosine schedule.
  296. min_lr_rate (`float`, *optional*):
  297. The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
  298. Return:
  299. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  300. """
  301. if min_lr is not None and min_lr_rate is not None:
  302. raise ValueError("Only one of min_lr or min_lr_rate should be set")
  303. elif min_lr is not None:
  304. min_lr_rate = min_lr / optimizer.defaults["lr"]
  305. elif min_lr_rate is None:
  306. raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
  307. lr_lambda = partial(
  308. _get_cosine_schedule_with_warmup_lr_lambda,
  309. num_warmup_steps=num_warmup_steps,
  310. num_training_steps=num_training_steps,
  311. num_cycles=num_cycles,
  312. min_lr_rate=min_lr_rate,
  313. )
  314. return LambdaLR(optimizer, lr_lambda, last_epoch)
  315. def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
  316. current_step: int,
  317. *,
  318. num_warmup_steps: int,
  319. num_training_steps: int,
  320. num_cycles: float,
  321. min_lr_rate: float = 0.0,
  322. warmup_lr_rate: float | None = None,
  323. ):
  324. current_step = float(current_step)
  325. num_warmup_steps = float(num_warmup_steps)
  326. num_training_steps = float(num_training_steps)
  327. if current_step < num_warmup_steps:
  328. if warmup_lr_rate is None:
  329. return (current_step + 1.0) / max(1.0, num_warmup_steps)
  330. else:
  331. warmup_lr_rate = float(warmup_lr_rate)
  332. return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1))
  333. progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps))
  334. factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
  335. factor = factor * (1 - min_lr_rate) + min_lr_rate
  336. return max(0, factor)
  337. def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
  338. optimizer: Optimizer,
  339. num_warmup_steps: int,
  340. num_training_steps: int,
  341. num_cycles: float = 0.5,
  342. last_epoch: int = -1,
  343. min_lr: float | None = None,
  344. min_lr_rate: float | None = None,
  345. warmup_lr_rate: float | None = None,
  346. ):
  347. """
  348. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  349. initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
  350. initial lr set in the optimizer.
  351. Args:
  352. optimizer ([`~torch.optim.Optimizer`]):
  353. The optimizer for which to schedule the learning rate.
  354. num_warmup_steps (`int`):
  355. The number of steps for the warmup phase.
  356. num_training_steps (`int`):
  357. The total number of training steps.
  358. num_cycles (`float`, *optional*, defaults to 0.5):
  359. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  360. following a half-cosine).
  361. last_epoch (`int`, *optional*, defaults to -1):
  362. The index of the last epoch when resuming training.
  363. min_lr (`float`, *optional*):
  364. The minimum learning rate to reach after the cosine schedule.
  365. min_lr_rate (`float`, *optional*):
  366. The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
  367. warmup_lr_rate (`float`, *optional*):
  368. The minimum learning rate as a ratio of the start learning rate. If not set, `warmup_lr_rate` will be treated as float(1/num_warmup_steps).
  369. Return:
  370. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  371. """
  372. if min_lr is not None and min_lr_rate is not None:
  373. raise ValueError("Only one of min_lr or min_lr_rate should be set")
  374. elif min_lr is not None:
  375. min_lr_rate = min_lr / optimizer.defaults["lr"]
  376. elif min_lr_rate is None:
  377. raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
  378. lr_lambda = partial(
  379. _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda,
  380. num_warmup_steps=num_warmup_steps,
  381. num_training_steps=num_training_steps,
  382. num_cycles=num_cycles,
  383. min_lr_rate=min_lr_rate,
  384. warmup_lr_rate=warmup_lr_rate,
  385. )
  386. return LambdaLR(optimizer, lr_lambda, last_epoch)
  387. def _get_wsd_scheduler_lambda(
  388. current_step: int,
  389. *,
  390. num_warmup_steps: int,
  391. num_stable_steps: int,
  392. num_decay_steps: int,
  393. warmup_type: str,
  394. decay_type: str,
  395. min_lr_ratio: float,
  396. num_cycles: float,
  397. ):
  398. if current_step < num_warmup_steps:
  399. progress = float(current_step) / float(max(1, num_warmup_steps))
  400. if warmup_type == "linear":
  401. factor = progress
  402. elif warmup_type == "cosine":
  403. factor = 0.5 * (1.0 - math.cos(math.pi * progress))
  404. elif warmup_type == "1-sqrt":
  405. factor = 1.0 - math.sqrt(1.0 - progress)
  406. factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
  407. return max(0.0, factor)
  408. if current_step < num_warmup_steps + num_stable_steps:
  409. return 1.0
  410. if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
  411. progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
  412. if decay_type == "linear":
  413. factor = 1.0 - progress
  414. elif decay_type == "cosine":
  415. factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
  416. elif decay_type == "1-sqrt":
  417. factor = 1.0 - math.sqrt(progress)
  418. factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
  419. return max(0.0, factor)
  420. return min_lr_ratio
  421. def get_wsd_schedule(
  422. optimizer: Optimizer,
  423. num_warmup_steps: int,
  424. num_decay_steps: int,
  425. num_training_steps: int | None = None,
  426. num_stable_steps: int | None = None,
  427. warmup_type: str = "linear",
  428. decay_type: str = "cosine",
  429. min_lr_ratio: float = 0,
  430. num_cycles: float = 0.5,
  431. last_epoch: int = -1,
  432. ):
  433. """
  434. Create a schedule with a learning rate that has three stages:
  435. 1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
  436. 2. stable: constant learning rate.
  437. 3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.
  438. Args:
  439. optimizer ([`~torch.optim.Optimizer`]):
  440. The optimizer for which to schedule the learning rate.
  441. num_warmup_steps (`int`):
  442. The number of steps for the warmup phase.
  443. num_decay_steps (`int`):
  444. The number of steps for the decay phase.
  445. num_training_steps (`int`, *optional*):
  446. The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
  447. num_stable_steps (`int`, *optional*):
  448. The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate.
  449. warmup_type (`str`, *optional*, defaults to "linear"):
  450. The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
  451. decay_type (`str`, *optional*, defaults to "cosine"):
  452. The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
  453. min_lr_ratio (`float`, *optional*, defaults to 0):
  454. The minimum learning rate as a ratio of the initial learning rate.
  455. num_cycles (`float`, *optional*, defaults to 0.5):
  456. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  457. following a half-cosine).
  458. last_epoch (`int`, *optional*, defaults to -1):
  459. The index of the last epoch when resuming training.
  460. Return:
  461. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  462. """
  463. if num_training_steps is None and num_stable_steps is None:
  464. raise ValueError("Either num_training_steps or num_stable_steps must be specified.")
  465. if num_training_steps is not None and num_stable_steps is not None:
  466. warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")
  467. if warmup_type not in ["linear", "cosine", "1-sqrt"]:
  468. raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
  469. if decay_type not in ["linear", "cosine", "1-sqrt"]:
  470. raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
  471. if num_stable_steps is None:
  472. num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
  473. lr_lambda = partial(
  474. _get_wsd_scheduler_lambda,
  475. num_warmup_steps=num_warmup_steps,
  476. num_stable_steps=num_stable_steps,
  477. num_decay_steps=num_decay_steps,
  478. warmup_type=warmup_type,
  479. decay_type=decay_type,
  480. min_lr_ratio=min_lr_ratio,
  481. num_cycles=num_cycles,
  482. )
  483. return LambdaLR(optimizer, lr_lambda, last_epoch)
  484. class StreamingAverage:
  485. """Rolling window average for smoothing metric values.
  486. Maintains a sliding window of values and computes their average,
  487. useful for smoothing noisy metric values before making learning rate decisions.
  488. Args:
  489. window_size (`int`):
  490. The maximum number of values to keep in the rolling window.
  491. """
  492. def __init__(self, window_size: int) -> None:
  493. self.window_size: int = window_size
  494. self.values: list[float] = []
  495. self.sum: float = 0.0
  496. def streamavg(self, value: float) -> float:
  497. """Add a value and return the current rolling average."""
  498. self.values.append(value)
  499. self.sum += value
  500. if len(self.values) > self.window_size:
  501. removed = self.values.pop(0)
  502. self.sum -= removed
  503. return self.sum / len(self.values)
  504. def state_dict(self) -> dict[str, Any]:
  505. return {
  506. "window_size": self.window_size,
  507. "values": self.values.copy(),
  508. "sum": self.sum,
  509. }
  510. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  511. self.window_size = state_dict.get("window_size", self.window_size)
  512. self.values = state_dict.get("values", []).copy()
  513. self.sum = state_dict.get("sum", 0.0)
  514. class GreedyLR:
  515. """Adaptive learning rate scheduler that responds to training metrics.
  516. GreedyLR dynamically adjusts the learning rate based on training performance:
  517. - Increases LR when metrics improve consistently (divides by factor)
  518. - Decreases LR when metrics plateau (multiplies by factor)
  519. This differs from traditional schedulers like cosine annealing by responding
  520. to actual training dynamics rather than following a predetermined schedule.
  521. Reference: `GreedyLR: A Novel Adaptive Learning Rate Scheduler <https://arxiv.org/abs/2512.14527>`_
  522. Args:
  523. optimizer ([`~torch.optim.Optimizer`]):
  524. The optimizer for which to schedule the learning rate.
  525. mode (`str`, *optional*, defaults to `"min"`):
  526. One of 'min' or 'max'. In 'min' mode, LR will be reduced when the
  527. metric has stopped decreasing; in 'max' mode when it has stopped increasing.
  528. factor (`float`, *optional*, defaults to 0.95):
  529. Factor by which the learning rate will be adjusted. LR is multiplied by
  530. factor on plateau and divided by factor on improvement. Must be < 1.0.
  531. patience (`int`, *optional*, defaults to 10):
  532. Number of epochs with no improvement after which learning rate will be adjusted.
  533. threshold (`float`, *optional*, defaults to 1e-06):
  534. Threshold for measuring the new optimum.
  535. threshold_mode (`str`, *optional*, defaults to `"abs"`):
  536. One of 'rel' or 'abs'.
  537. cooldown (`int`, *optional*, defaults to 0):
  538. Number of epochs to wait before resuming normal operation after LR has been reduced.
  539. warmup (`int`, *optional*, defaults to 0):
  540. Number of epochs to wait before resuming normal operation after LR has been increased.
  541. min_lr (`float` or `list[float]`, *optional*, defaults to 0.001):
  542. A lower bound on the learning rate.
  543. max_lr (`float` or `list[float]`, *optional*, defaults to 1.0):
  544. An upper bound on the learning rate.
  545. eps (`float`, *optional*, defaults to 1e-08):
  546. Minimal decay applied to lr.
  547. verbose (`bool`, *optional*, defaults to `False`):
  548. If True, prints a message to stdout for each update.
  549. smooth (`bool`, *optional*, defaults to `False`):
  550. If True, applies streaming average smoothing to metrics.
  551. window_size (`int`, *optional*, defaults to 50):
  552. The window size for the streaming average when smooth=True.
  553. reset_start (`int`, *optional*, defaults to 500):
  554. Number of steps to wait at min_lr before resetting to initial state.
  555. Example:
  556. ```python
  557. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  558. >>> scheduler = GreedyLR(optimizer, mode="min", patience=10)
  559. >>> for epoch in range(100):
  560. ... train(...)
  561. ... val_loss = validate(...)
  562. ... scheduler.step(val_loss)
  563. ```
  564. """
  565. def __init__(
  566. self,
  567. optimizer: Optimizer,
  568. mode: str = "min",
  569. factor: float = 0.95,
  570. patience: int = 10,
  571. threshold: float = 1e-6,
  572. threshold_mode: str = "abs",
  573. cooldown: int = 0,
  574. warmup: int = 0,
  575. min_lr: float | list[float] = 1e-3,
  576. max_lr: float | list[float] = 1.0,
  577. eps: float = 1e-8,
  578. verbose: bool = False,
  579. smooth: bool = False,
  580. window_size: int = 50,
  581. reset_start: int = 500,
  582. ) -> None:
  583. if factor >= 1.0:
  584. raise ValueError("Factor should be < 1.0.")
  585. if not isinstance(optimizer, Optimizer):
  586. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  587. self.optimizer = optimizer
  588. self.factor = factor
  589. self.patience = patience
  590. self.verbose = verbose
  591. self.cooldown = cooldown
  592. self.warmup = warmup
  593. self.cooldown_counter = 0
  594. self.warmup_counter = 0
  595. self.mode = mode
  596. self.threshold = threshold
  597. self.threshold_mode = threshold_mode
  598. self.eps = eps
  599. self.smooth = smooth
  600. self.window_size = window_size
  601. self.reset_start = reset_start
  602. self.reset_start_original = reset_start
  603. self.last_epoch = 0
  604. if isinstance(min_lr, (list, tuple)):
  605. if len(min_lr) != len(optimizer.param_groups):
  606. raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
  607. self.min_lrs = list(min_lr)
  608. else:
  609. self.min_lrs = [min_lr] * len(optimizer.param_groups)
  610. if isinstance(max_lr, (list, tuple)):
  611. if len(max_lr) != len(optimizer.param_groups):
  612. raise ValueError(f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}")
  613. self.max_lrs = list(max_lr)
  614. else:
  615. self.max_lrs = [max_lr] * len(optimizer.param_groups)
  616. self._init_lrs = [group["lr"] for group in optimizer.param_groups]
  617. self._last_lr = self._init_lrs.copy()
  618. self.best: float = float("inf") if mode == "min" else float("-inf")
  619. self.num_bad_epochs = 0
  620. self.num_good_epochs = 0
  621. if mode not in ("min", "max"):
  622. raise ValueError(f"mode {mode} is unknown!")
  623. if threshold_mode not in ("rel", "abs"):
  624. raise ValueError(f"threshold mode {threshold_mode} is unknown!")
  625. self._streaming_avg: StreamingAverage | None = None
  626. if smooth:
  627. self._streaming_avg = StreamingAverage(window_size)
  628. def step(self, metrics: float, epoch: int | None = None) -> None:
  629. """Perform a scheduler step based on the given metrics.
  630. Args:
  631. metrics (`float`):
  632. The metric value to use for LR adjustment decisions.
  633. epoch (`int`, *optional*):
  634. The current epoch number. If None, uses internal counter.
  635. """
  636. current = float(metrics)
  637. if self.smooth and self._streaming_avg is not None:
  638. current = self._streaming_avg.streamavg(current)
  639. if epoch is None:
  640. epoch = self.last_epoch + 1
  641. self.last_epoch = epoch
  642. if self.cooldown_counter > 0:
  643. self.cooldown_counter -= 1
  644. self.num_bad_epochs = 0
  645. self.num_good_epochs = 0
  646. elif self.warmup_counter > 0:
  647. self.warmup_counter -= 1
  648. self.num_bad_epochs = 0
  649. self.num_good_epochs = 0
  650. else:
  651. if self.is_better(current, self.best):
  652. self.best = current
  653. self.num_bad_epochs = 0
  654. self.num_good_epochs += 1
  655. else:
  656. self.num_bad_epochs += 1
  657. self.num_good_epochs = 0
  658. if self.num_good_epochs > self.patience:
  659. self._increase_lr(epoch)
  660. self.warmup_counter = self.warmup
  661. self.num_good_epochs = 0
  662. elif self.num_bad_epochs > self.patience:
  663. self._reduce_lr(epoch)
  664. self.cooldown_counter = self.cooldown
  665. self.num_bad_epochs = 0
  666. self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
  667. def is_better(self, current: float, best: float) -> bool:
  668. if self.mode == "min":
  669. if self.threshold_mode == "rel":
  670. return current < best * (1.0 - self.threshold)
  671. else:
  672. return current < best - self.threshold
  673. else:
  674. if self.threshold_mode == "rel":
  675. return current > best * (1.0 + self.threshold)
  676. else:
  677. return current > best + self.threshold
  678. def _reduce_lr(self, epoch: int) -> None:
  679. all_at_min = True
  680. for i, param_group in enumerate(self.optimizer.param_groups):
  681. old_lr = float(param_group["lr"])
  682. new_lr = max(old_lr * self.factor, self.min_lrs[i])
  683. if old_lr - new_lr > self.eps:
  684. param_group["lr"] = new_lr
  685. if self.verbose:
  686. print(f"Epoch {epoch}: reducing learning rate of group {i} to {new_lr:.4e}.")
  687. if param_group["lr"] > self.min_lrs[i]:
  688. all_at_min = False
  689. if all_at_min:
  690. self.reset_start -= 1
  691. if self.reset_start <= 0:
  692. self._reset()
  693. def _increase_lr(self, epoch: int) -> None:
  694. for i, param_group in enumerate(self.optimizer.param_groups):
  695. old_lr = float(param_group["lr"])
  696. new_lr = min(old_lr / self.factor, self.max_lrs[i])
  697. if new_lr - old_lr > self.eps:
  698. param_group["lr"] = new_lr
  699. if self.verbose:
  700. print(f"Epoch {epoch}: increasing learning rate of group {i} to {new_lr:.4e}.")
  701. self.reset_start = self.reset_start_original
  702. def _reset(self) -> None:
  703. for i, param_group in enumerate(self.optimizer.param_groups):
  704. param_group["lr"] = self._init_lrs[i]
  705. self.best = float("inf") if self.mode == "min" else float("-inf")
  706. self.num_bad_epochs = 0
  707. self.num_good_epochs = 0
  708. self.cooldown_counter = 0
  709. self.warmup_counter = 0
  710. self.reset_start = self.reset_start_original
  711. if self.smooth and self._streaming_avg is not None:
  712. self._streaming_avg = StreamingAverage(self.window_size)
  713. if self.verbose:
  714. print("Scheduler reset to initial state.")
  715. def get_last_lr(self) -> list[float]:
  716. """Return last computed learning rate by current scheduler."""
  717. return self._last_lr
  718. def state_dict(self) -> dict[str, Any]:
  719. """Return the state of the scheduler as a dictionary."""
  720. state = {
  721. "factor": self.factor,
  722. "min_lrs": self.min_lrs,
  723. "max_lrs": self.max_lrs,
  724. "patience": self.patience,
  725. "verbose": self.verbose,
  726. "cooldown": self.cooldown,
  727. "warmup": self.warmup,
  728. "cooldown_counter": self.cooldown_counter,
  729. "warmup_counter": self.warmup_counter,
  730. "mode": self.mode,
  731. "threshold": self.threshold,
  732. "threshold_mode": self.threshold_mode,
  733. "best": self.best,
  734. "num_bad_epochs": self.num_bad_epochs,
  735. "num_good_epochs": self.num_good_epochs,
  736. "eps": self.eps,
  737. "last_epoch": self.last_epoch,
  738. "smooth": self.smooth,
  739. "window_size": self.window_size,
  740. "reset_start": self.reset_start,
  741. "reset_start_original": self.reset_start_original,
  742. "_last_lr": self._last_lr,
  743. "_init_lrs": self._init_lrs,
  744. }
  745. if self.smooth and self._streaming_avg is not None:
  746. state["_streaming_avg"] = self._streaming_avg.state_dict()
  747. return state
  748. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  749. """Load state from a dictionary."""
  750. self.factor = state_dict.get("factor", self.factor)
  751. self.min_lrs = state_dict.get("min_lrs", self.min_lrs)
  752. self.max_lrs = state_dict.get("max_lrs", self.max_lrs)
  753. self.patience = state_dict.get("patience", self.patience)
  754. self.verbose = state_dict.get("verbose", self.verbose)
  755. self.cooldown = state_dict.get("cooldown", self.cooldown)
  756. self.warmup = state_dict.get("warmup", self.warmup)
  757. self.cooldown_counter = state_dict.get("cooldown_counter", self.cooldown_counter)
  758. self.warmup_counter = state_dict.get("warmup_counter", self.warmup_counter)
  759. self.mode = state_dict.get("mode", self.mode)
  760. self.threshold = state_dict.get("threshold", self.threshold)
  761. self.threshold_mode = state_dict.get("threshold_mode", self.threshold_mode)
  762. self.best = state_dict.get("best", self.best)
  763. self.num_bad_epochs = state_dict.get("num_bad_epochs", self.num_bad_epochs)
  764. self.num_good_epochs = state_dict.get("num_good_epochs", self.num_good_epochs)
  765. self.eps = state_dict.get("eps", self.eps)
  766. self.last_epoch = state_dict.get("last_epoch", self.last_epoch)
  767. self.smooth = state_dict.get("smooth", self.smooth)
  768. self.window_size = state_dict.get("window_size", self.window_size)
  769. self.reset_start = state_dict.get("reset_start", self.reset_start)
  770. self.reset_start_original = state_dict.get("reset_start_original", self.reset_start_original)
  771. self._last_lr = state_dict.get("_last_lr", self._last_lr)
  772. self._init_lrs = state_dict.get("_init_lrs", self._init_lrs)
  773. if "_streaming_avg" in state_dict:
  774. if self._streaming_avg is None:
  775. self._streaming_avg = StreamingAverage(self.window_size)
  776. self._streaming_avg.load_state_dict(state_dict["_streaming_avg"])
  777. if "_last_lr" in state_dict:
  778. for param_group, lr in zip(self.optimizer.param_groups, self._last_lr):
  779. param_group["lr"] = lr
  780. def get_greedy_schedule(optimizer: Optimizer, **kwargs):
  781. """
  782. Create an adaptive learning rate scheduler that adjusts LR based on training metrics.
  783. Args:
  784. optimizer ([`~torch.optim.Optimizer`]):
  785. The optimizer for which to schedule the learning rate.
  786. kwargs (`dict`, *optional*):
  787. Extra parameters passed to the scheduler. See [`GreedyLR`] for possible parameters.
  788. Return:
  789. [`GreedyLR`] with the appropriate schedule.
  790. """
  791. return GreedyLR(optimizer, **kwargs)
# Lookup table from each `SchedulerType` member to the factory function that builds the matching
# scheduler. `get_scheduler` indexes this mapping with the requested type and then decides which
# arguments to pass to the selected factory.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
    SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
    SchedulerType.GREEDY: get_greedy_schedule,
}
  806. def get_scheduler(
  807. name: str | SchedulerType,
  808. optimizer: Optimizer,
  809. num_warmup_steps: int | None = None,
  810. num_training_steps: int | None = None,
  811. scheduler_specific_kwargs: dict | None = None,
  812. ):
  813. """
  814. Unified API to get any scheduler from its name.
  815. Args:
  816. name (`str` or `SchedulerType`):
  817. The name of the scheduler to use.
  818. optimizer (`torch.optim.Optimizer`):
  819. The optimizer that will be used during training.
  820. num_warmup_steps (`int`, *optional*):
  821. The number of warmup steps to do. This is not required by all schedulers (hence the argument being
  822. optional), the function will raise an error if it's unset and the scheduler type requires it.
  823. num_training_steps (`int``, *optional*):
  824. The number of training steps to do. This is not required by all schedulers (hence the argument being
  825. optional), the function will raise an error if it's unset and the scheduler type requires it.
  826. scheduler_specific_kwargs (`dict`, *optional*):
  827. Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
  828. parameters will cause the scheduler function to raise a TypeError.
  829. """
  830. name = SchedulerType(name)
  831. schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
  832. # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
  833. # recursively call `get_scheduler` to get the proper schedulers on each parameter
  834. if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
  835. optimizer_dict = optimizer.optimizer_dict
  836. scheduler_dict = {}
  837. for param in optimizer_dict:
  838. scheduler_dict[param] = get_scheduler(
  839. name,
  840. optimizer=optimizer_dict[param],
  841. num_warmup_steps=num_warmup_steps,
  842. num_training_steps=num_training_steps,
  843. scheduler_specific_kwargs=scheduler_specific_kwargs,
  844. )
  845. def scheduler_hook(param):
  846. # Since the optimizer hook has been already attached we only need to
  847. # attach the scheduler hook, the gradients have been zeroed here
  848. scheduler_dict[param].step()
  849. for param in optimizer_dict:
  850. if param.requires_grad:
  851. param.register_post_accumulate_grad_hook(scheduler_hook)
  852. return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])
  853. if name == SchedulerType.CONSTANT:
  854. return schedule_func(optimizer)
  855. if scheduler_specific_kwargs is None:
  856. scheduler_specific_kwargs = {}
  857. if name == SchedulerType.REDUCE_ON_PLATEAU:
  858. return schedule_func(optimizer, **scheduler_specific_kwargs)
  859. if name == SchedulerType.GREEDY:
  860. return schedule_func(optimizer, **scheduler_specific_kwargs)
  861. # All other schedulers require `num_warmup_steps`
  862. if num_warmup_steps is None:
  863. raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
  864. if name == SchedulerType.CONSTANT_WITH_WARMUP:
  865. return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
  866. if name == SchedulerType.INVERSE_SQRT:
  867. return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs)
  868. # wsd scheduler requires either num_training_steps or num_stable_steps
  869. if name == SchedulerType.WARMUP_STABLE_DECAY:
  870. return schedule_func(
  871. optimizer,
  872. num_warmup_steps=num_warmup_steps,
  873. num_training_steps=num_training_steps,
  874. **scheduler_specific_kwargs,
  875. )
  876. # All other schedulers require `num_training_steps`
  877. if num_training_steps is None:
  878. raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
  879. return schedule_func(
  880. optimizer,
  881. num_warmup_steps=num_warmup_steps,
  882. num_training_steps=num_training_steps,
  883. **scheduler_specific_kwargs,
  884. )
class Adafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://huggingface.co/papers/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

        - Training without LR warmup or clip_threshold is not recommended.

           - use scheduled LR warm-up to fixed LR
           - use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235)
        - Disable relative updates
        - Use scale_parameter=False
        - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        """Check that the option combination is consistent and register the per-group defaults.

        Raises:
            ValueError: If an explicit `lr` is combined with `relative_step=True`, or if `warmup_init=True`
                is requested without `relative_step=True`.
        """
        # A manual learning rate and the internally computed relative step are mutually exclusive.
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        # Warm-up initialisation only affects the relative step, so it needs `relative_step`.
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the step size for one parameter given its group options and optimizer state.

        Reads `"step"` and (when `scale_parameter` is set) `"RMS"` from `param_state`.
        """
        # Start from the external learning rate; it is replaced below when relative steps are enabled.
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            # With warm-up the cap grows linearly with the step count, otherwise it is a constant 1e-2.
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            # Scale by the parameter's RMS, bounded below by the second `eps` entry.
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return `(factored, use_first_moment)` for a parameter of the given shape.

        Factored second-moment statistics are used for shapes with at least two dimensions; the first
        moment is tracked only when `beta1` is set.
        """
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Return the root mean square of `tensor` (its 2-norm divided by the square root of its size)."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Combine the row and column statistics into a single tensor by broadcasting their product.

        `step` multiplies the result with the gradient to form the update.
        """
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        # Row statistics are normalised by their mean over the last dim before the reciprocal square root.
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                # Parameters that received no gradient are left untouched.
                if p.grad is None:
                    continue
                grad = p.grad
                # Half-precision gradients are converted with `.float()` before any arithmetic.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row statistics drop the last dim; column statistics drop the second-to-last dim.
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Existing state is converted with `.to(grad)` so it matches the current gradient tensor.
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                # For half-precision parameters the update is applied to a `.float()` copy and copied
                # back at the end; otherwise the parameter itself is modified in place.
                p_data_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                # Step-dependent decay coefficient for the squared-gradient running averages.
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                # Squared gradient, regularised by the first `eps` entry.
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Update clipping: the divisor is clamped to at least 1, so the update is only ever
                # scaled down, and only when its RMS exceeds `clip_threshold`.
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    # Running average of the (already scaled) updates; the averaged value is applied.
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    # Weight decay: subtract `weight_decay * lr` times the parameter from itself.
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                # Write the result back when the work was done on a float copy.
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_data_fp32)

        return loss
  1083. class AdafactorSchedule(LambdaLR):
  1084. """
  1085. Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
  1086. for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
  1087. It returns `initial_lr` during startup and the actual `lr` during stepping.
  1088. """
  1089. def __init__(self, optimizer, initial_lr=0.0):
  1090. def lr_lambda(_):
  1091. return initial_lr
  1092. for group in optimizer.param_groups:
  1093. group["initial_lr"] = initial_lr
  1094. super().__init__(optimizer, lr_lambda)
  1095. for group in optimizer.param_groups:
  1096. del group["initial_lr"]
  1097. def get_lr(self):
  1098. opt = self.optimizer
  1099. lrs = [
  1100. opt._get_lr(group, opt.state[group["params"][0]])
  1101. for group in opt.param_groups
  1102. if group["params"][0].grad is not None
  1103. ]
  1104. if len(lrs) == 0:
  1105. lrs = self.base_lrs # if called before stepping
  1106. return lrs
  1107. def get_adafactor_schedule(optimizer, initial_lr=0.0):
  1108. """
  1109. Get a proxy schedule for [`~optimization.Adafactor`]
  1110. Args:
  1111. optimizer ([`~torch.optim.Optimizer`]):
  1112. The optimizer for which to schedule the learning rate.
  1113. initial_lr (`float`, *optional*, defaults to 0.0):
  1114. Initial lr
  1115. Return:
  1116. [`~optimization.Adafactor`] proxy schedule object.
  1117. """
  1118. return AdafactorSchedule(optimizer, initial_lr)