| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587 |
- # mypy: allow-untyped-defs
- r"""Implementation for Stochastic Weight Averaging implementation."""
- import itertools
- import math
- import warnings
- from collections.abc import Callable, Iterable
- from copy import deepcopy
- from typing import Any, cast, Literal, Union
- from typing_extensions import override
- import torch
- from torch import Tensor
- from torch.nn import Module
- from torch.optim.lr_scheduler import _format_param, LRScheduler
- from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
- from .optimizer import Optimizer
# Public API exported by ``from torch.optim.swa_utils import *``.
__all__ = [
    "AveragedModel",
    "update_bn",
    "SWALR",
    "get_ema_multi_avg_fn",
    "get_swa_multi_avg_fn",
    "get_ema_avg_fn",
    "get_swa_avg_fn",
]
# Used by AveragedModel.update_parameters to bucket tensors by (device, dtype)
# before handing each bucket to a multi-tensor averaging function.
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

# Type alias for the tensor collections accepted by the ``*_multi_avg_fn``
# helpers and by AveragedModel's ``multi_avg_fn`` argument: a tuple or a list
# of tensors.
PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]]
- def get_ema_multi_avg_fn(decay=0.999):
- """Get the function applying exponential moving average (EMA) across multiple params.
- The EMA is computed as:
- .. math::
- W_0^{\\text{EMA}} = W_0^{\\text{model}}
- .. math::
- W_{t+1}^{\\text{EMA}} = \\text{decay} \\times W_t^{\\text{EMA}} + (1 - \\text{decay}) \\times W_{t+1}^{\\text{model}}
- where :math:`W_t^{\\text{EMA}}` is the EMA parameter at step :math:`t`,
- :math:`W_t^{\\text{model}}` is the model parameter at step :math:`t`,
- and :math:`\\text{decay}` is the decay rate (default: 0.999).
- Args:
- decay (float): Decay rate for EMA. Must be in the range [0, 1]. Default: 0.999
- Returns:
- Callable: A function that updates EMA parameters given current model parameters
- """
- if decay < 0.0 or decay > 1.0:
- raise ValueError(
- f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
- )
- @torch.no_grad()
- def ema_update(
- ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _
- ) -> None:
- # foreach lerp only handles float and complex
- if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(
- ema_param_list[0]
- ):
- torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
- else:
- for p_ema, p_model in zip(ema_param_list, current_param_list, strict=True):
- p_ema.copy_(p_ema * decay + p_model * (1 - decay))
- return ema_update
- def get_swa_multi_avg_fn():
- """Get the function applying stochastic weight average (SWA) across multiple params."""
- @torch.no_grad()
- def swa_update(
- averaged_param_list: PARAM_LIST,
- current_param_list: PARAM_LIST,
- num_averaged: Tensor | int,
- ) -> None:
- # foreach lerp only handles float and complex
- if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(
- averaged_param_list[0]
- ):
- torch._foreach_lerp_(
- averaged_param_list,
- current_param_list,
- cast(float, 1 / (num_averaged + 1)),
- )
- else:
- diffs = torch._foreach_sub(current_param_list, averaged_param_list)
- if isinstance(num_averaged, Tensor):
- torch._foreach_addcdiv_(
- averaged_param_list,
- diffs,
- [num_averaged + 1] * len(averaged_param_list),
- )
- else:
- torch._foreach_add_(
- averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
- )
- return swa_update
def get_ema_avg_fn(decay=0.999):
    """Get the function applying exponential moving average (EMA) to a single param.

    The EMA is computed as:

    .. math::
        W_0^{\\text{EMA}} = W_0^{\\text{model}}

    .. math::
        W_{t+1}^{\\text{EMA}} = \\text{decay} \\times W_t^{\\text{EMA}} + (1 - \\text{decay}) \\times W_{t+1}^{\\text{model}}

    where :math:`W_t^{\\text{EMA}}` is the EMA parameter at step :math:`t`,
    :math:`W_t^{\\text{model}}` is the model parameter at step :math:`t`,
    and :math:`\\text{decay}` is the decay rate (default: 0.999).

    Unlike :func:`get_ema_multi_avg_fn`, the returned function operates on one
    tensor at a time and is not in-place.

    Args:
        decay (float): Decay rate for EMA. Must be in the range [0, 1]. Default: 0.999

    Returns:
        Callable: ``ema_update(ema_param, current_param, num_averaged)``, which
        returns the next EMA value as a new tensor. It does not modify its
        inputs, and ``num_averaged`` is ignored.

    Raises:
        ValueError: if ``decay`` lies outside ``[0, 1]``.
    """
    if decay < 0.0 or decay > 1.0:
        raise ValueError(
            f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
        )

    @torch.no_grad()
    def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
        # num_averaged is accepted only to match the avg_fn calling convention
        # used by AveragedModel; EMA does not depend on it.
        return decay * ema_param + (1 - decay) * current_param

    return ema_update
- def get_swa_avg_fn():
- """Get the function applying stochastic weight average (SWA) across a single param."""
- @torch.no_grad()
- def swa_update(
- averaged_param: Tensor, current_param: Tensor, num_averaged: Tensor | int
- ):
- return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
- return swa_update
- class AveragedModel(Module):
- r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
- Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
- Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
- Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
- (UAI 2018).
- Exponential Moving Average is a variation of `Polyak averaging`_,
- but using exponential weights instead of equal weights across iterations.
- AveragedModel class creates a copy of the provided module :attr:`model`
- on the device :attr:`device` and allows to compute running averages of the
- parameters of the :attr:`model`.
- Args:
- model (torch.nn.Module): model to use with SWA/EMA
- device (torch.device, optional): if provided, the averaged model will be
- stored on the :attr:`device`
- avg_fn (function, optional): the averaging function used to update
- parameters; the function must take in the current value of the
- :class:`AveragedModel` parameter, the current value of :attr:`model`
- parameter, and the number of models already averaged; if None,
- an equally weighted average is used (default: None)
- multi_avg_fn (function, optional): the averaging function used to update
- parameters inplace; the function must take in the current values of the
- :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
- parameters as a list, and the number of models already averaged; if None,
- an equally weighted average is used (default: None)
- use_buffers (bool): if ``True``, it will compute running averages for
- both the parameters and the buffers of the model. (default: ``False``)
- Example:
- >>> # xdoctest: +SKIP("undefined variables")
- >>> loader, optimizer, model, loss_fn = ...
- >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
- >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
- >>> T_max=300)
- >>> swa_start = 160
- >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
- >>> for i in range(300):
- >>> for input, target in loader:
- >>> optimizer.zero_grad()
- >>> loss_fn(model(input), target).backward()
- >>> optimizer.step()
- >>> if i > swa_start:
- >>> swa_model.update_parameters(model)
- >>> swa_scheduler.step()
- >>> else:
- >>> scheduler.step()
- >>>
- >>> # Update bn statistics for the swa_model at the end
- >>> torch.optim.swa_utils.update_bn(loader, swa_model)
- You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
- If no averaging function is provided, the default is to compute
- equally-weighted average of the weights (SWA).
- Example:
- >>> # xdoctest: +SKIP("undefined variables")
- >>> # Compute exponential moving averages of the weights and buffers
- >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
- >>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)
- .. note::
- When using SWA/EMA with models containing Batch Normalization you may
- need to update the activation statistics for Batch Normalization.
- This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
- or by setting :attr:`use_buffers` to `True`. The first approach updates the
- statistics in a post-training step by passing data through the model. The
- second does it during the parameter update phase by averaging all buffers.
- Empirical evidence has shown that updating the statistics in normalization
- layers increases accuracy, but you may wish to empirically test which
- approach yields the best results in your problem.
- .. note::
- :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.
- .. note::
- When :meth:`update_parameters` is called for the first time (i.e.
- :attr:`n_averaged` is `0`) the parameters of `model` are copied
- to the parameters of :class:`AveragedModel`. For every subsequent
- call of :meth:`update_parameters` the function `avg_fn` is used
- to update the parameters.
- .. _Averaging Weights Leads to Wider Optima and Better Generalization:
- https://arxiv.org/abs/1803.05407
- .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
- Average:
- https://arxiv.org/abs/1806.05594
- .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
- https://arxiv.org/abs/1904.11943
- .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
- Generalizes Well:
- https://arxiv.org/abs/2001.02312
- .. _Polyak averaging:
- https://paperswithcode.com/method/polyak-averaging
- """
- n_averaged: Tensor
- def __init__(
- self,
- model: Module,
- device: int | torch.device | None = None,
- avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
- multi_avg_fn: Callable[[PARAM_LIST, PARAM_LIST, Tensor | int], None]
- | None = None,
- use_buffers=False,
- ) -> None: # noqa: D107
- super().__init__()
- if avg_fn is not None and multi_avg_fn is not None:
- raise AssertionError(
- "Only one of avg_fn and multi_avg_fn should be provided"
- )
- self.module = deepcopy(model)
- if device is not None:
- self.module = self.module.to(device)
- self.register_buffer(
- "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
- )
- self.avg_fn = avg_fn
- self.multi_avg_fn = multi_avg_fn
- self.use_buffers = use_buffers
- def forward(self, *args, **kwargs):
- """Forward pass."""
- return self.module(*args, **kwargs)
- def update_parameters(self, model: Module) -> None:
- """Update model parameters."""
- self_param = (
- # pyrefly: ignore [bad-argument-type]
- itertools.chain(self.module.parameters(), self.module.buffers())
- if self.use_buffers
- else self.parameters()
- )
- model_param = (
- # pyrefly: ignore [bad-argument-type]
- itertools.chain(model.parameters(), model.buffers())
- if self.use_buffers
- else model.parameters()
- )
- self_param_detached: list[Tensor | None] = []
- model_param_detached: list[Tensor | None] = []
- copy_param = bool(self.n_averaged == 0)
- for p_averaged, p_model in zip(self_param, model_param, strict=False):
- p_model_ = p_model.detach().to(p_averaged.device)
- self_param_detached.append(p_averaged.detach())
- model_param_detached.append(p_model_)
- if copy_param:
- p_averaged.detach().copy_(p_model_)
- if self.n_averaged > 0:
- if self.multi_avg_fn is not None or self.avg_fn is None:
- grouped_tensors = _group_tensors_by_device_and_dtype(
- [self_param_detached, model_param_detached]
- )
- for (device, _), (
- [self_params, model_params],
- _,
- ) in grouped_tensors.items():
- if self.multi_avg_fn:
- self.multi_avg_fn(
- self_params, # type: ignore[arg-type]
- model_params, # type: ignore[arg-type]
- self.n_averaged.to(device),
- )
- elif (
- device is not None
- and device.type in _get_foreach_kernels_supported_devices()
- ):
- multi_avg_fn = get_swa_multi_avg_fn()
- multi_avg_fn(
- self_params, model_params, self.n_averaged.to(device)
- )
- else:
- avg_fn = get_swa_avg_fn()
- n_averaged = self.n_averaged.to(device)
- for p_averaged, p_model in zip( # type: ignore[assignment]
- self_params, model_params, strict=True
- ):
- # pyrefly: ignore [missing-attribute]
- p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
- else:
- for p_averaged, p_model in zip( # type: ignore[assignment]
- self_param_detached, model_param_detached, strict=True
- ):
- # pyrefly: ignore [missing-attribute]
- n_averaged = self.n_averaged.to(p_averaged.device)
- # pyrefly: ignore [missing-attribute]
- p_averaged.detach().copy_(
- # pyrefly: ignore [missing-attribute, bad-argument-type]
- self.avg_fn(p_averaged.detach(), p_model, n_averaged)
- )
- if not self.use_buffers:
- # If not apply running averages to the buffers,
- # keep the buffers in sync with the source model.
- for b_swa, b_model in zip(
- self.module.buffers(), model.buffers(), strict=True
- ):
- b_swa.detach().copy_(b_model.detach().to(b_swa.device))
- self.n_averaged += 1
- @torch.no_grad()
- def update_bn(
- loader: Iterable[Any],
- model: Module,
- device: int | torch.device | None = None,
- ) -> None:
- r"""Update BatchNorm running_mean, running_var buffers in the model.
- It performs one pass over data in `loader` to estimate the activation
- statistics for BatchNorm layers in the model.
- Args:
- loader (torch.utils.data.DataLoader): dataset loader to compute the
- activation statistics on. Each data batch should be either a
- tensor, or a list/tuple whose first element is a tensor
- containing data.
- model (torch.nn.Module): model for which we seek to update BatchNorm
- statistics.
- device (torch.device, optional): If set, data will be transferred to
- :attr:`device` before being passed into :attr:`model`.
- Example:
- >>> # xdoctest: +SKIP("Undefined variables")
- >>> loader, model = ...
- >>> torch.optim.swa_utils.update_bn(loader, model)
- .. note::
- The `update_bn` utility assumes that each data batch in :attr:`loader`
- is either a tensor or a list or tuple of tensors; in the latter case it
- is assumed that :meth:`model.forward()` should be called on the first
- element of the list or tuple corresponding to the data batch.
- """
- momenta = {}
- for module in model.modules():
- if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
- module.reset_running_stats()
- momenta[module] = module.momentum
- if not momenta:
- return
- was_training = model.training
- model.train()
- for module in momenta:
- module.momentum = None
- for input in loader:
- if isinstance(input, (list, tuple)):
- input = input[0]
- if device is not None:
- input = input.to(device)
- model(input)
- for bn_module in momenta:
- bn_module.momentum = momenta[bn_module]
- model.train(was_training)
- class SWALR(LRScheduler):
- r"""Anneals the learning rate in each parameter group to a fixed value.
- This learning rate scheduler is meant to be used with Stochastic Weight
- Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
- Args:
- optimizer (torch.optim.Optimizer): wrapped optimizer
- swa_lrs (float or list): the learning rate value for all param groups
- together or separately for each group.
- annealing_epochs (int): number of epochs in the annealing phase
- (default: 10)
- annealing_strategy (str): "cos" or "linear"; specifies the annealing
- strategy: "cos" for cosine annealing, "linear" for linear annealing
- (default: "cos")
- last_epoch (int): the index of the last epoch (default: -1)
- The :class:`SWALR` scheduler can be used together with other
- schedulers to switch to a constant learning rate late in the training
- as in the example below.
- Example:
- >>> # xdoctest: +SKIP("Undefined variables")
- >>> loader, optimizer, model = ...
- >>> lr_lambda = lambda epoch: 0.9
- >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
- >>> lr_lambda=lr_lambda)
- >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
- >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
- >>> swa_start = 160
- >>> for i in range(300):
- >>> for input, target in loader:
- >>> optimizer.zero_grad()
- >>> loss_fn(model(input), target).backward()
- >>> optimizer.step()
- >>> if i > swa_start:
- >>> swa_scheduler.step()
- >>> else:
- >>> scheduler.step()
- .. _Averaging Weights Leads to Wider Optima and Better Generalization:
- https://arxiv.org/abs/1803.05407
- """
- def __init__(
- self,
- optimizer: Optimizer,
- swa_lr: float,
- anneal_epochs=10,
- anneal_strategy: Literal["cos", "linear"] = "cos",
- last_epoch=-1,
- ) -> None: # noqa: D107
- swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
- for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True):
- group["swa_lr"] = swa_lr
- if anneal_strategy not in ["cos", "linear"]:
- raise ValueError(
- "anneal_strategy must by one of 'cos' or 'linear', "
- f"instead got {anneal_strategy}"
- )
- self._set_anneal_func(anneal_strategy)
- if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
- raise ValueError(
- f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
- )
- self.anneal_epochs = anneal_epochs
- super().__init__(optimizer, last_epoch)
- @staticmethod
- def _linear_anneal(t):
- return t
- @staticmethod
- def _cosine_anneal(t):
- return (1 - math.cos(math.pi * t)) / 2
- @staticmethod
- def _get_initial_lr(lr, swa_lr, alpha):
- if alpha == 1:
- return swa_lr
- return (lr - alpha * swa_lr) / (1 - alpha)
- @override
- def get_lr(self):
- r"""Compute the next learning rate for each of the optimizer's
- :attr:`~torch.optim.Optimizer.param_groups`.
- Uses :attr:`anneal_func` to interpolate between each group's
- ``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs`
- epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate
- fixed at ``group["swa_lr"]``.
- Returns:
- list[float | Tensor]: A :class:`list` of learning rates for each of
- the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
- same types as their current ``group["lr"]``\s.
- .. note::
- If you're trying to inspect the most recent learning rate, use
- :meth:`get_last_lr()` instead.
- .. note::
- The returned :class:`~torch.Tensor`\s are copies, and never alias
- the optimizer's ``group["lr"]``\s.
- """
- # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
- # so we ignore the type error here. See `LRScheduler.step()` for more details.
- if not self._get_lr_called_within_step:
- warnings.warn(
- "To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.",
- UserWarning,
- stacklevel=2,
- )
- # Set in `LRScheduler._initial_step()`
- step = self._step_count - 1
- if self.anneal_epochs == 0:
- step = max(1, step)
- # pyrefly: ignore [no-matching-overload]
- prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
- prev_alpha = self.anneal_func(prev_t)
- prev_lrs = [
- self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
- for group in self.optimizer.param_groups
- ]
- # pyrefly: ignore [no-matching-overload]
- t = max(0, min(1, step / max(1, self.anneal_epochs)))
- alpha = self.anneal_func(t)
- return [
- group["swa_lr"] * alpha + lr * (1 - alpha)
- for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True)
- ]
- def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]) -> None:
- self._anneal_strategy = anneal_strategy
- if anneal_strategy == "cos":
- self.anneal_func = self._cosine_anneal
- else:
- self.anneal_func = self._linear_anneal
- @override
- def state_dict(self) -> dict[str, Any]:
- """Return the state of the scheduler as a :class:`dict`.
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer or anneal_func.
- """
- return {
- key: value
- for key, value in self.__dict__.items()
- if key not in ("optimizer", "anneal_func")
- }
- @override
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
- """Load the scheduler's state.
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- self.__dict__.update(state_dict)
- self._set_anneal_func(self._anneal_strategy)
|