| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260 |
- """ Exponential Moving Average (EMA) of model updates
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import logging
- from collections import OrderedDict
- from copy import deepcopy
- from typing import Optional
- import torch
- import torch.nn as nn
- _logger = logging.getLogger(__name__)
- class ModelEma:
- """ Model Exponential Moving Average (DEPRECATED)
- Keep a moving average of everything in the model state_dict (parameters and buffers).
- This version is deprecated, it does not work with scripted models. Will be removed eventually.
- This is intended to allow functionality like
- https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
- A smoothed version of the weights is necessary for some training schemes to perform well.
- E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
- RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
- smoothing of weights to match results. Pay attention to the decay constant you are using
- relative to your update count per epoch.
- To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
- disable validation of the EMA weights. Validation will have to be done manually in a separate
- process, or after the training stops converging.
- This class is sensitive where it is initialized in the sequence of model init,
- GPU assignment and distributed training wrappers.
- """
- def __init__(self, model, decay=0.9999, device='', resume=''):
- # make a copy of the model for accumulating moving average of weights
- self.ema = deepcopy(model)
- self.ema.eval()
- self.decay = decay
- self.device = device # perform ema on different device from model if set
- if device:
- self.ema.to(device=device)
- self.ema_has_module = hasattr(self.ema, 'module')
- if resume:
- self._load_checkpoint(resume)
- for p in self.ema.parameters():
- p.requires_grad_(False)
- def _load_checkpoint(self, checkpoint_path):
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
- assert isinstance(checkpoint, dict)
- if 'state_dict_ema' in checkpoint:
- new_state_dict = OrderedDict()
- for k, v in checkpoint['state_dict_ema'].items():
- # ema model may have been wrapped by DataParallel, and need module prefix
- if self.ema_has_module:
- name = 'module.' + k if not k.startswith('module') else k
- else:
- name = k
- new_state_dict[name] = v
- self.ema.load_state_dict(new_state_dict)
- _logger.info("Loaded state_dict_ema")
- else:
- _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
- def update(self, model):
- # correct a mismatch in state dict keys
- needs_module = hasattr(model, 'module') and not self.ema_has_module
- with torch.no_grad():
- msd = model.state_dict()
- for k, ema_v in self.ema.state_dict().items():
- if needs_module:
- k = 'module.' + k
- model_v = msd[k].detach()
- if self.device:
- model_v = model_v.to(device=self.device)
- ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
- class ModelEmaV2(nn.Module):
- """ Model Exponential Moving Average V2
- Keep a moving average of everything in the model state_dict (parameters and buffers).
- V2 of this module is simpler, it does not match params/buffers based on name but simply
- iterates in order. It works with torchscript (JIT of full model).
- This is intended to allow functionality like
- https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
- A smoothed version of the weights is necessary for some training schemes to perform well.
- E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
- RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
- smoothing of weights to match results. Pay attention to the decay constant you are using
- relative to your update count per epoch.
- To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
- disable validation of the EMA weights. Validation will have to be done manually in a separate
- process, or after the training stops converging.
- This class is sensitive where it is initialized in the sequence of model init,
- GPU assignment and distributed training wrappers.
- """
- def __init__(self, model, decay=0.9999, device=None):
- super().__init__()
- # make a copy of the model for accumulating moving average of weights
- self.module = deepcopy(model)
- self.module.eval()
- self.decay = decay
- self.device = device # perform ema on different device from model if set
- if self.device is not None:
- self.module.to(device=device)
- def _update(self, model, update_fn):
- with torch.no_grad():
- for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
- if self.device is not None:
- model_v = model_v.to(device=self.device)
- ema_v.copy_(update_fn(ema_v, model_v))
- def update(self, model):
- self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
- def set(self, model):
- self._update(model, update_fn=lambda e, m: m)
- def forward(self, *args, **kwargs):
- return self.module(*args, **kwargs)
- class ModelEmaV3(nn.Module):
- """ Model Exponential Moving Average V3
- Keep a moving average of everything in the model state_dict (parameters and buffers).
- V3 of this module leverages for_each and in-place operations for faster performance.
- Decay warmup based on code by @crowsonkb, her comments:
- If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
- good values for models you plan to train for a million or more steps (reaches decay
- factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
- you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
- 215.4k steps).
- This is intended to allow functionality like
- https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
- To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
- disable validation of the EMA weights. Validation will have to be done manually in a separate
- process, or after the training stops converging.
- This class is sensitive where it is initialized in the sequence of model init,
- GPU assignment and distributed training wrappers.
- """
- def __init__(
- self,
- model,
- decay: float = 0.9999,
- min_decay: float = 0.0,
- update_after_step: int = 0,
- use_warmup: bool = False,
- warmup_gamma: float = 1.0,
- warmup_power: float = 2/3,
- device: Optional[torch.device] = None,
- foreach: bool = True,
- exclude_buffers: bool = False,
- ):
- super().__init__()
- # make a copy of the model for accumulating moving average of weights
- self.module = deepcopy(model)
- self.module.eval()
- self.decay = decay
- self.min_decay = min_decay
- self.update_after_step = update_after_step
- self.use_warmup = use_warmup
- self.warmup_gamma = warmup_gamma
- self.warmup_power = warmup_power
- self.foreach = foreach
- self.device = device # perform ema on different device from model if set
- self.exclude_buffers = exclude_buffers
- if self.device is not None and device != next(model.parameters()).device:
- self.foreach = False # cannot use foreach methods with different devices
- self.module.to(device=device)
- def get_decay(self, step: Optional[int] = None) -> float:
- """
- Compute the decay factor for the exponential moving average.
- """
- if step is None:
- return self.decay
- step = max(0, step - self.update_after_step - 1)
- if step <= 0:
- return 0.0
- if self.use_warmup:
- decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
- decay = max(min(decay, self.decay), self.min_decay)
- else:
- decay = self.decay
- return decay
- @torch.no_grad()
- def update(self, model, step: Optional[int] = None):
- decay = self.get_decay(step)
- if self.exclude_buffers:
- self.apply_update_no_buffers_(model, decay)
- else:
- self.apply_update_(model, decay)
- def apply_update_(self, model, decay: float):
- # interpolate parameters and buffers
- if self.foreach:
- ema_lerp_values = []
- model_lerp_values = []
- for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
- if ema_v.is_floating_point():
- ema_lerp_values.append(ema_v)
- model_lerp_values.append(model_v)
- else:
- ema_v.copy_(model_v)
- if hasattr(torch, '_foreach_lerp_'):
- torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
- else:
- torch._foreach_mul_(ema_lerp_values, scalar=decay)
- torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
- else:
- for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
- if ema_v.is_floating_point():
- ema_v.lerp_(model_v.to(device=self.device), weight=1. - decay)
- else:
- ema_v.copy_(model_v.to(device=self.device))
- def apply_update_no_buffers_(self, model, decay: float):
- # interpolate parameters, copy buffers
- ema_params = tuple(self.module.parameters())
- model_params = tuple(model.parameters())
- if self.foreach:
- if hasattr(torch, '_foreach_lerp_'):
- torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
- else:
- torch._foreach_mul_(ema_params, scalar=decay)
- torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
- else:
- for ema_p, model_p in zip(ema_params, model_params):
- ema_p.lerp_(model_p.to(device=self.device), weight=1. - decay)
- for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
- ema_b.copy_(model_b.to(device=self.device))
- @torch.no_grad()
- def set(self, model):
- for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
- ema_v.copy_(model_v.to(device=self.device))
- def forward(self, *args, **kwargs):
- return self.module(*args, **kwargs)
|