model_ema.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. """ Exponential Moving Average (EMA) of model updates
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import logging
  5. from collections import OrderedDict
  6. from copy import deepcopy
  7. from typing import Optional
  8. import torch
  9. import torch.nn as nn
  10. _logger = logging.getLogger(__name__)
  11. class ModelEma:
  12. """ Model Exponential Moving Average (DEPRECATED)
  13. Keep a moving average of everything in the model state_dict (parameters and buffers).
  14. This version is deprecated, it does not work with scripted models. Will be removed eventually.
  15. This is intended to allow functionality like
  16. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  17. A smoothed version of the weights is necessary for some training schemes to perform well.
  18. E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
  19. RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
  20. smoothing of weights to match results. Pay attention to the decay constant you are using
  21. relative to your update count per epoch.
  22. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  23. disable validation of the EMA weights. Validation will have to be done manually in a separate
  24. process, or after the training stops converging.
  25. This class is sensitive where it is initialized in the sequence of model init,
  26. GPU assignment and distributed training wrappers.
  27. """
  28. def __init__(self, model, decay=0.9999, device='', resume=''):
  29. # make a copy of the model for accumulating moving average of weights
  30. self.ema = deepcopy(model)
  31. self.ema.eval()
  32. self.decay = decay
  33. self.device = device # perform ema on different device from model if set
  34. if device:
  35. self.ema.to(device=device)
  36. self.ema_has_module = hasattr(self.ema, 'module')
  37. if resume:
  38. self._load_checkpoint(resume)
  39. for p in self.ema.parameters():
  40. p.requires_grad_(False)
  41. def _load_checkpoint(self, checkpoint_path):
  42. checkpoint = torch.load(checkpoint_path, map_location='cpu')
  43. assert isinstance(checkpoint, dict)
  44. if 'state_dict_ema' in checkpoint:
  45. new_state_dict = OrderedDict()
  46. for k, v in checkpoint['state_dict_ema'].items():
  47. # ema model may have been wrapped by DataParallel, and need module prefix
  48. if self.ema_has_module:
  49. name = 'module.' + k if not k.startswith('module') else k
  50. else:
  51. name = k
  52. new_state_dict[name] = v
  53. self.ema.load_state_dict(new_state_dict)
  54. _logger.info("Loaded state_dict_ema")
  55. else:
  56. _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
  57. def update(self, model):
  58. # correct a mismatch in state dict keys
  59. needs_module = hasattr(model, 'module') and not self.ema_has_module
  60. with torch.no_grad():
  61. msd = model.state_dict()
  62. for k, ema_v in self.ema.state_dict().items():
  63. if needs_module:
  64. k = 'module.' + k
  65. model_v = msd[k].detach()
  66. if self.device:
  67. model_v = model_v.to(device=self.device)
  68. ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
  69. class ModelEmaV2(nn.Module):
  70. """ Model Exponential Moving Average V2
  71. Keep a moving average of everything in the model state_dict (parameters and buffers).
  72. V2 of this module is simpler, it does not match params/buffers based on name but simply
  73. iterates in order. It works with torchscript (JIT of full model).
  74. This is intended to allow functionality like
  75. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  76. A smoothed version of the weights is necessary for some training schemes to perform well.
  77. E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
  78. RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
  79. smoothing of weights to match results. Pay attention to the decay constant you are using
  80. relative to your update count per epoch.
  81. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  82. disable validation of the EMA weights. Validation will have to be done manually in a separate
  83. process, or after the training stops converging.
  84. This class is sensitive where it is initialized in the sequence of model init,
  85. GPU assignment and distributed training wrappers.
  86. """
  87. def __init__(self, model, decay=0.9999, device=None):
  88. super().__init__()
  89. # make a copy of the model for accumulating moving average of weights
  90. self.module = deepcopy(model)
  91. self.module.eval()
  92. self.decay = decay
  93. self.device = device # perform ema on different device from model if set
  94. if self.device is not None:
  95. self.module.to(device=device)
  96. def _update(self, model, update_fn):
  97. with torch.no_grad():
  98. for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
  99. if self.device is not None:
  100. model_v = model_v.to(device=self.device)
  101. ema_v.copy_(update_fn(ema_v, model_v))
  102. def update(self, model):
  103. self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
  104. def set(self, model):
  105. self._update(model, update_fn=lambda e, m: m)
  106. def forward(self, *args, **kwargs):
  107. return self.module(*args, **kwargs)
  108. class ModelEmaV3(nn.Module):
  109. """ Model Exponential Moving Average V3
  110. Keep a moving average of everything in the model state_dict (parameters and buffers).
  111. V3 of this module leverages for_each and in-place operations for faster performance.
  112. Decay warmup based on code by @crowsonkb, her comments:
  113. If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
  114. good values for models you plan to train for a million or more steps (reaches decay
  115. factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
  116. you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
  117. 215.4k steps).
  118. This is intended to allow functionality like
  119. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  120. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  121. disable validation of the EMA weights. Validation will have to be done manually in a separate
  122. process, or after the training stops converging.
  123. This class is sensitive where it is initialized in the sequence of model init,
  124. GPU assignment and distributed training wrappers.
  125. """
  126. def __init__(
  127. self,
  128. model,
  129. decay: float = 0.9999,
  130. min_decay: float = 0.0,
  131. update_after_step: int = 0,
  132. use_warmup: bool = False,
  133. warmup_gamma: float = 1.0,
  134. warmup_power: float = 2/3,
  135. device: Optional[torch.device] = None,
  136. foreach: bool = True,
  137. exclude_buffers: bool = False,
  138. ):
  139. super().__init__()
  140. # make a copy of the model for accumulating moving average of weights
  141. self.module = deepcopy(model)
  142. self.module.eval()
  143. self.decay = decay
  144. self.min_decay = min_decay
  145. self.update_after_step = update_after_step
  146. self.use_warmup = use_warmup
  147. self.warmup_gamma = warmup_gamma
  148. self.warmup_power = warmup_power
  149. self.foreach = foreach
  150. self.device = device # perform ema on different device from model if set
  151. self.exclude_buffers = exclude_buffers
  152. if self.device is not None and device != next(model.parameters()).device:
  153. self.foreach = False # cannot use foreach methods with different devices
  154. self.module.to(device=device)
  155. def get_decay(self, step: Optional[int] = None) -> float:
  156. """
  157. Compute the decay factor for the exponential moving average.
  158. """
  159. if step is None:
  160. return self.decay
  161. step = max(0, step - self.update_after_step - 1)
  162. if step <= 0:
  163. return 0.0
  164. if self.use_warmup:
  165. decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
  166. decay = max(min(decay, self.decay), self.min_decay)
  167. else:
  168. decay = self.decay
  169. return decay
  170. @torch.no_grad()
  171. def update(self, model, step: Optional[int] = None):
  172. decay = self.get_decay(step)
  173. if self.exclude_buffers:
  174. self.apply_update_no_buffers_(model, decay)
  175. else:
  176. self.apply_update_(model, decay)
  177. def apply_update_(self, model, decay: float):
  178. # interpolate parameters and buffers
  179. if self.foreach:
  180. ema_lerp_values = []
  181. model_lerp_values = []
  182. for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
  183. if ema_v.is_floating_point():
  184. ema_lerp_values.append(ema_v)
  185. model_lerp_values.append(model_v)
  186. else:
  187. ema_v.copy_(model_v)
  188. if hasattr(torch, '_foreach_lerp_'):
  189. torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
  190. else:
  191. torch._foreach_mul_(ema_lerp_values, scalar=decay)
  192. torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
  193. else:
  194. for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
  195. if ema_v.is_floating_point():
  196. ema_v.lerp_(model_v.to(device=self.device), weight=1. - decay)
  197. else:
  198. ema_v.copy_(model_v.to(device=self.device))
  199. def apply_update_no_buffers_(self, model, decay: float):
  200. # interpolate parameters, copy buffers
  201. ema_params = tuple(self.module.parameters())
  202. model_params = tuple(model.parameters())
  203. if self.foreach:
  204. if hasattr(torch, '_foreach_lerp_'):
  205. torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
  206. else:
  207. torch._foreach_mul_(ema_params, scalar=decay)
  208. torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
  209. else:
  210. for ema_p, model_p in zip(ema_params, model_params):
  211. ema_p.lerp_(model_p.to(device=self.device), weight=1. - decay)
  212. for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
  213. ema_b.copy_(model_b.to(device=self.device))
  214. @torch.no_grad()
  215. def set(self, model):
  216. for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
  217. ema_v.copy_(model_v.to(device=self.device))
  218. def forward(self, *args, **kwargs):
  219. return self.module(*args, **kwargs)