swa_utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. # mypy: allow-untyped-defs
r"""Utilities for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA)."""
  3. import itertools
  4. import math
  5. import warnings
  6. from collections.abc import Callable, Iterable
  7. from copy import deepcopy
  8. from typing import Any, cast, Literal, Union
  9. from typing_extensions import override
  10. import torch
  11. from torch import Tensor
  12. from torch.nn import Module
  13. from torch.optim.lr_scheduler import _format_param, LRScheduler
  14. from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
  15. from .optimizer import Optimizer
  16. __all__ = [
  17. "AveragedModel",
  18. "update_bn",
  19. "SWALR",
  20. "get_ema_multi_avg_fn",
  21. "get_swa_multi_avg_fn",
  22. "get_ema_avg_fn",
  23. "get_swa_avg_fn",
  24. ]
  25. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
  26. PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]]
  27. def get_ema_multi_avg_fn(decay=0.999):
  28. """Get the function applying exponential moving average (EMA) across multiple params.
  29. The EMA is computed as:
  30. .. math::
  31. W_0^{\\text{EMA}} = W_0^{\\text{model}}
  32. .. math::
  33. W_{t+1}^{\\text{EMA}} = \\text{decay} \\times W_t^{\\text{EMA}} + (1 - \\text{decay}) \\times W_{t+1}^{\\text{model}}
  34. where :math:`W_t^{\\text{EMA}}` is the EMA parameter at step :math:`t`,
  35. :math:`W_t^{\\text{model}}` is the model parameter at step :math:`t`,
  36. and :math:`\\text{decay}` is the decay rate (default: 0.999).
  37. Args:
  38. decay (float): Decay rate for EMA. Must be in the range [0, 1]. Default: 0.999
  39. Returns:
  40. Callable: A function that updates EMA parameters given current model parameters
  41. """
  42. if decay < 0.0 or decay > 1.0:
  43. raise ValueError(
  44. f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
  45. )
  46. @torch.no_grad()
  47. def ema_update(
  48. ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _
  49. ) -> None:
  50. # foreach lerp only handles float and complex
  51. if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(
  52. ema_param_list[0]
  53. ):
  54. torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
  55. else:
  56. for p_ema, p_model in zip(ema_param_list, current_param_list, strict=True):
  57. p_ema.copy_(p_ema * decay + p_model * (1 - decay))
  58. return ema_update
  59. def get_swa_multi_avg_fn():
  60. """Get the function applying stochastic weight average (SWA) across multiple params."""
  61. @torch.no_grad()
  62. def swa_update(
  63. averaged_param_list: PARAM_LIST,
  64. current_param_list: PARAM_LIST,
  65. num_averaged: Tensor | int,
  66. ) -> None:
  67. # foreach lerp only handles float and complex
  68. if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(
  69. averaged_param_list[0]
  70. ):
  71. torch._foreach_lerp_(
  72. averaged_param_list,
  73. current_param_list,
  74. cast(float, 1 / (num_averaged + 1)),
  75. )
  76. else:
  77. diffs = torch._foreach_sub(current_param_list, averaged_param_list)
  78. if isinstance(num_averaged, Tensor):
  79. torch._foreach_addcdiv_(
  80. averaged_param_list,
  81. diffs,
  82. [num_averaged + 1] * len(averaged_param_list),
  83. )
  84. else:
  85. torch._foreach_add_(
  86. averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
  87. )
  88. return swa_update
  89. def get_ema_avg_fn(decay=0.999):
  90. """Get the function applying exponential moving average (EMA) across multiple params.
  91. The EMA is computed as:
  92. .. math::
  93. W_0^{\\text{EMA}} = W_0^{\\text{model}}
  94. .. math::
  95. W_{t+1}^{\\text{EMA}} = \\text{decay} \\times W_t^{\\text{EMA}} + (1 - \\text{decay}) \\times W_{t+1}^{\\text{model}}
  96. where :math:`W_t^{\\text{EMA}}` is the EMA parameter at step :math:`t`,
  97. :math:`W_t^{\\text{model}}` is the model parameter at step :math:`t`,
  98. and :math:`\\text{decay}` is the decay rate (default: 0.999).
  99. Args:
  100. decay (float): Decay rate for EMA. Must be in the range [0, 1]. Default: 0.999
  101. Returns:
  102. Callable: A function that updates EMA parameters given current model parameters
  103. """
  104. if decay < 0.0 or decay > 1.0:
  105. raise ValueError(
  106. f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
  107. )
  108. @torch.no_grad()
  109. def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
  110. return decay * ema_param + (1 - decay) * current_param
  111. return ema_update
  112. def get_swa_avg_fn():
  113. """Get the function applying stochastic weight average (SWA) across a single param."""
  114. @torch.no_grad()
  115. def swa_update(
  116. averaged_param: Tensor, current_param: Tensor, num_averaged: Tensor | int
  117. ):
  118. return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
  119. return swa_update
class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows to compute running averages of the
    parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on the :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter, and the number of models already averaged, and return
            the new averaged value; if None, an equally weighted average is
            used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters inplace; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model. (default: ``False``)

    Raises:
        AssertionError: if both :attr:`avg_fn` and :attr:`multi_avg_fn` are provided.

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        >>>                                     T_max=300)
        >>> swa_start = 160
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_model.update_parameters(model)
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()
        >>>
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute
    equally-weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9),
        >>>             use_buffers=True)

    .. note::
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        or by setting :attr:`use_buffers` to `True`. The first approach updates the
        statistics in a post-training step by passing data through the model. The
        second does it during the parameter update phase by averaging all buffers.
        Empirical evidence has shown that updating the statistics in normalization
        layers increases accuracy, but you may wish to empirically test which
        approach yields the best results in your problem.

    .. note::
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the averaging function
        (`multi_avg_fn` if given, otherwise `avg_fn` if given, otherwise the
        default equal-weight average) is used to update the parameters.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
        Average:
        https://arxiv.org/abs/1806.05594
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
        https://arxiv.org/abs/1904.11943
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """

    # Number of models folded into the average so far (a long scalar buffer).
    n_averaged: Tensor

    def __init__(
        self,
        model: Module,
        device: int | torch.device | None = None,
        avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
        multi_avg_fn: Callable[[PARAM_LIST, PARAM_LIST, Tensor | int], None]
        | None = None,
        use_buffers: bool = False,
    ) -> None:  # noqa: D107
        super().__init__()
        # The per-tensor and the multi-tensor averaging hooks are mutually exclusive.
        if avg_fn is not None and multi_avg_fn is not None:
            raise AssertionError(
                "Only one of avg_fn and multi_avg_fn should be provided"
            )
        # Deep copy so the averaged weights do not share storage with ``model``.
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        # Registered as a buffer (not a parameter), so ``self.parameters()``
        # below never includes it.
        self.register_buffer(
            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
        )
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        """Run the wrapped averaged module on the given arguments."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: Module) -> None:
        """Fold the current weights of ``model`` into the running average.

        On the first call (``n_averaged == 0``) the weights of ``model`` are
        copied into the averaged copy. On later calls they are combined using
        ``multi_avg_fn`` if given, otherwise ``avg_fn`` if given, otherwise an
        equal-weight (SWA) running mean. ``n_averaged`` is incremented at the end.

        When ``use_buffers`` is ``False`` the buffers are not averaged; they are
        copied from ``model`` on every call instead.

        Args:
            model (torch.nn.Module): model whose current parameters (and buffers,
                if ``use_buffers``) are averaged in. Presumably it has the same
                structure as the module this ``AveragedModel`` wraps — confirm.
        """
        # Tensors to update, in a fixed order matching ``model_param`` below.
        self_param = (
            # pyrefly: ignore [bad-argument-type]
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            # pyrefly: ignore [bad-argument-type]
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
        )
        self_param_detached: list[Tensor | None] = []
        model_param_detached: list[Tensor | None] = []
        # First call: initialize the averaged copy from ``model`` instead of averaging.
        copy_param = bool(self.n_averaged == 0)
        # NOTE(review): ``strict=False`` silently stops at the shorter sequence if
        # the two models differ in their number of parameters/buffers.
        for p_averaged, p_model in zip(self_param, model_param, strict=False):
            # Bring the source tensor onto the averaged tensor's device.
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if copy_param:
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                # Bucket tensors by (device, dtype) so each bucket can be updated
                # with multi-tensor (foreach) operations.
                grouped_tensors = _group_tensors_by_device_and_dtype(
                    [self_param_detached, model_param_detached]
                )
                for (device, _), (
                    [self_params, model_params],
                    _,
                ) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(
                            self_params,  # type: ignore[arg-type]
                            model_params,  # type: ignore[arg-type]
                            self.n_averaged.to(device),
                        )
                    elif (
                        device is not None
                        and device.type in _get_foreach_kernels_supported_devices()
                    ):
                        # Default SWA using foreach kernels on supported devices.
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)
                        )
                    else:
                        # Default SWA one tensor at a time for other devices.
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(  # type: ignore[assignment]
                            self_params, model_params, strict=True
                        ):
                            # pyrefly: ignore [missing-attribute]
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                # User-supplied per-tensor ``avg_fn``: its return value overwrites
                # each averaged tensor.
                for p_averaged, p_model in zip(  # type: ignore[assignment]
                    self_param_detached, model_param_detached, strict=True
                ):
                    # pyrefly: ignore [missing-attribute]
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    # pyrefly: ignore [missing-attribute]
                    p_averaged.detach().copy_(
                        # pyrefly: ignore [missing-attribute, bad-argument-type]
                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                    )

        if not self.use_buffers:
            # If not apply running averages to the buffers,
            # keep the buffers in sync with the source model.
            for b_swa, b_model in zip(
                self.module.buffers(), model.buffers(), strict=True
            ):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1
  307. @torch.no_grad()
  308. def update_bn(
  309. loader: Iterable[Any],
  310. model: Module,
  311. device: int | torch.device | None = None,
  312. ) -> None:
  313. r"""Update BatchNorm running_mean, running_var buffers in the model.
  314. It performs one pass over data in `loader` to estimate the activation
  315. statistics for BatchNorm layers in the model.
  316. Args:
  317. loader (torch.utils.data.DataLoader): dataset loader to compute the
  318. activation statistics on. Each data batch should be either a
  319. tensor, or a list/tuple whose first element is a tensor
  320. containing data.
  321. model (torch.nn.Module): model for which we seek to update BatchNorm
  322. statistics.
  323. device (torch.device, optional): If set, data will be transferred to
  324. :attr:`device` before being passed into :attr:`model`.
  325. Example:
  326. >>> # xdoctest: +SKIP("Undefined variables")
  327. >>> loader, model = ...
  328. >>> torch.optim.swa_utils.update_bn(loader, model)
  329. .. note::
  330. The `update_bn` utility assumes that each data batch in :attr:`loader`
  331. is either a tensor or a list or tuple of tensors; in the latter case it
  332. is assumed that :meth:`model.forward()` should be called on the first
  333. element of the list or tuple corresponding to the data batch.
  334. """
  335. momenta = {}
  336. for module in model.modules():
  337. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  338. module.reset_running_stats()
  339. momenta[module] = module.momentum
  340. if not momenta:
  341. return
  342. was_training = model.training
  343. model.train()
  344. for module in momenta:
  345. module.momentum = None
  346. for input in loader:
  347. if isinstance(input, (list, tuple)):
  348. input = input[0]
  349. if device is not None:
  350. input = input.to(device)
  351. model(input)
  352. for bn_module in momenta:
  353. bn_module.momentum = momenta[bn_module]
  354. model.train(was_training)
  355. class SWALR(LRScheduler):
  356. r"""Anneals the learning rate in each parameter group to a fixed value.
  357. This learning rate scheduler is meant to be used with Stochastic Weight
  358. Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
  359. Args:
  360. optimizer (torch.optim.Optimizer): wrapped optimizer
  361. swa_lrs (float or list): the learning rate value for all param groups
  362. together or separately for each group.
  363. annealing_epochs (int): number of epochs in the annealing phase
  364. (default: 10)
  365. annealing_strategy (str): "cos" or "linear"; specifies the annealing
  366. strategy: "cos" for cosine annealing, "linear" for linear annealing
  367. (default: "cos")
  368. last_epoch (int): the index of the last epoch (default: -1)
  369. The :class:`SWALR` scheduler can be used together with other
  370. schedulers to switch to a constant learning rate late in the training
  371. as in the example below.
  372. Example:
  373. >>> # xdoctest: +SKIP("Undefined variables")
  374. >>> loader, optimizer, model = ...
  375. >>> lr_lambda = lambda epoch: 0.9
  376. >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
  377. >>> lr_lambda=lr_lambda)
  378. >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
  379. >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
  380. >>> swa_start = 160
  381. >>> for i in range(300):
  382. >>> for input, target in loader:
  383. >>> optimizer.zero_grad()
  384. >>> loss_fn(model(input), target).backward()
  385. >>> optimizer.step()
  386. >>> if i > swa_start:
  387. >>> swa_scheduler.step()
  388. >>> else:
  389. >>> scheduler.step()
  390. .. _Averaging Weights Leads to Wider Optima and Better Generalization:
  391. https://arxiv.org/abs/1803.05407
  392. """
  393. def __init__(
  394. self,
  395. optimizer: Optimizer,
  396. swa_lr: float,
  397. anneal_epochs=10,
  398. anneal_strategy: Literal["cos", "linear"] = "cos",
  399. last_epoch=-1,
  400. ) -> None: # noqa: D107
  401. swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
  402. for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True):
  403. group["swa_lr"] = swa_lr
  404. if anneal_strategy not in ["cos", "linear"]:
  405. raise ValueError(
  406. "anneal_strategy must by one of 'cos' or 'linear', "
  407. f"instead got {anneal_strategy}"
  408. )
  409. self._set_anneal_func(anneal_strategy)
  410. if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
  411. raise ValueError(
  412. f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
  413. )
  414. self.anneal_epochs = anneal_epochs
  415. super().__init__(optimizer, last_epoch)
  416. @staticmethod
  417. def _linear_anneal(t):
  418. return t
  419. @staticmethod
  420. def _cosine_anneal(t):
  421. return (1 - math.cos(math.pi * t)) / 2
  422. @staticmethod
  423. def _get_initial_lr(lr, swa_lr, alpha):
  424. if alpha == 1:
  425. return swa_lr
  426. return (lr - alpha * swa_lr) / (1 - alpha)
  427. @override
  428. def get_lr(self):
  429. r"""Compute the next learning rate for each of the optimizer's
  430. :attr:`~torch.optim.Optimizer.param_groups`.
  431. Uses :attr:`anneal_func` to interpolate between each group's
  432. ``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs`
  433. epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate
  434. fixed at ``group["swa_lr"]``.
  435. Returns:
  436. list[float | Tensor]: A :class:`list` of learning rates for each of
  437. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  438. same types as their current ``group["lr"]``\s.
  439. .. note::
  440. If you're trying to inspect the most recent learning rate, use
  441. :meth:`get_last_lr()` instead.
  442. .. note::
  443. The returned :class:`~torch.Tensor`\s are copies, and never alias
  444. the optimizer's ``group["lr"]``\s.
  445. """
  446. # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
  447. # so we ignore the type error here. See `LRScheduler.step()` for more details.
  448. if not self._get_lr_called_within_step:
  449. warnings.warn(
  450. "To get the last learning rate computed by the scheduler, "
  451. "please use `get_last_lr()`.",
  452. UserWarning,
  453. stacklevel=2,
  454. )
  455. # Set in `LRScheduler._initial_step()`
  456. step = self._step_count - 1
  457. if self.anneal_epochs == 0:
  458. step = max(1, step)
  459. # pyrefly: ignore [no-matching-overload]
  460. prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
  461. prev_alpha = self.anneal_func(prev_t)
  462. prev_lrs = [
  463. self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
  464. for group in self.optimizer.param_groups
  465. ]
  466. # pyrefly: ignore [no-matching-overload]
  467. t = max(0, min(1, step / max(1, self.anneal_epochs)))
  468. alpha = self.anneal_func(t)
  469. return [
  470. group["swa_lr"] * alpha + lr * (1 - alpha)
  471. for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True)
  472. ]
  473. def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]) -> None:
  474. self._anneal_strategy = anneal_strategy
  475. if anneal_strategy == "cos":
  476. self.anneal_func = self._cosine_anneal
  477. else:
  478. self.anneal_func = self._linear_anneal
  479. @override
  480. def state_dict(self) -> dict[str, Any]:
  481. """Return the state of the scheduler as a :class:`dict`.
  482. It contains an entry for every variable in self.__dict__ which
  483. is not the optimizer or anneal_func.
  484. """
  485. return {
  486. key: value
  487. for key, value in self.__dict__.items()
  488. if key not in ("optimizer", "anneal_func")
  489. }
  490. @override
  491. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  492. """Load the scheduler's state.
  493. Args:
  494. state_dict (dict): scheduler state. Should be an object returned
  495. from a call to :meth:`state_dict`.
  496. """
  497. self.__dict__.update(state_dict)
  498. self._set_anneal_func(self._anneal_strategy)