batchnorm.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905
  1. # mypy: allow-untyped-defs
  2. from typing import Any
  3. import torch
  4. from torch import Tensor
  5. from torch.nn import functional as F, init
  6. from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter
  7. from ._functions import SyncBatchNorm as sync_batch_norm
  8. from .lazy import LazyModuleMixin
  9. from .module import Module
  10. __all__ = [
  11. "BatchNorm1d",
  12. "LazyBatchNorm1d",
  13. "BatchNorm2d",
  14. "LazyBatchNorm2d",
  15. "BatchNorm3d",
  16. "LazyBatchNorm3d",
  17. "SyncBatchNorm",
  18. ]
  19. class _NormBase(Module):
  20. """Common base of _InstanceNorm and _BatchNorm."""
  21. _version = 2
  22. __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
  23. num_features: int
  24. eps: float
  25. momentum: float | None
  26. affine: bool
  27. track_running_stats: bool
  28. # WARNING: weight and bias purposely not defined here.
  29. # See https://github.com/pytorch/pytorch/issues/39670
  30. def __init__(
  31. self,
  32. num_features: int,
  33. eps: float = 1e-5,
  34. momentum: float | None = 0.1,
  35. affine: bool = True,
  36. track_running_stats: bool = True,
  37. device=None,
  38. dtype=None,
  39. ) -> None:
  40. factory_kwargs = {"device": device, "dtype": dtype}
  41. super().__init__()
  42. self.num_features = num_features
  43. self.eps = eps
  44. self.momentum = momentum
  45. self.affine = affine
  46. self.track_running_stats = track_running_stats
  47. if self.affine:
  48. self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
  49. self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
  50. else:
  51. self.register_parameter("weight", None)
  52. self.register_parameter("bias", None)
  53. if self.track_running_stats:
  54. self.register_buffer(
  55. "running_mean", torch.zeros(num_features, **factory_kwargs)
  56. )
  57. self.register_buffer(
  58. "running_var", torch.ones(num_features, **factory_kwargs)
  59. )
  60. self.running_mean: Tensor | None
  61. self.running_var: Tensor | None
  62. self.register_buffer(
  63. "num_batches_tracked",
  64. torch.tensor(
  65. 0,
  66. dtype=torch.long,
  67. # pyrefly: ignore [bad-argument-type]
  68. **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
  69. ),
  70. )
  71. self.num_batches_tracked: Tensor | None
  72. else:
  73. self.register_buffer("running_mean", None)
  74. self.register_buffer("running_var", None)
  75. self.register_buffer("num_batches_tracked", None)
  76. self.reset_parameters()
  77. def reset_running_stats(self) -> None:
  78. if self.track_running_stats:
  79. # running_mean/running_var/num_batches... are registered at runtime depending
  80. # if self.track_running_stats is on
  81. self.running_mean.zero_() # type: ignore[union-attr]
  82. self.running_var.fill_(1) # type: ignore[union-attr]
  83. self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
  84. def reset_parameters(self) -> None:
  85. self.reset_running_stats()
  86. if self.affine:
  87. init.ones_(self.weight)
  88. init.zeros_(self.bias)
  89. def _check_input_dim(self, input):
  90. raise NotImplementedError
  91. def extra_repr(self):
  92. return (
  93. "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
  94. "track_running_stats={track_running_stats}".format(**self.__dict__)
  95. )
  96. def _load_from_state_dict(
  97. self,
  98. state_dict,
  99. prefix,
  100. local_metadata,
  101. strict,
  102. missing_keys,
  103. unexpected_keys,
  104. error_msgs,
  105. ) -> None:
  106. version = local_metadata.get("version", None)
  107. if (version is None or version < 2) and self.track_running_stats:
  108. # at version 2: added num_batches_tracked buffer
  109. # this should have a default value of 0
  110. num_batches_tracked_key = prefix + "num_batches_tracked"
  111. if num_batches_tracked_key not in state_dict:
  112. state_dict[num_batches_tracked_key] = (
  113. self.num_batches_tracked
  114. if self.num_batches_tracked is not None
  115. and self.num_batches_tracked.device != torch.device("meta")
  116. else torch.tensor(0, dtype=torch.long)
  117. )
  118. super()._load_from_state_dict(
  119. state_dict,
  120. prefix,
  121. local_metadata,
  122. strict,
  123. missing_keys,
  124. unexpected_keys,
  125. error_msgs,
  126. )
  127. class _BatchNorm(_NormBase):
  128. def __init__(
  129. self,
  130. num_features: int,
  131. eps: float = 1e-5,
  132. momentum: float | None = 0.1,
  133. affine: bool = True,
  134. track_running_stats: bool = True,
  135. device=None,
  136. dtype=None,
  137. ) -> None:
  138. factory_kwargs = {"device": device, "dtype": dtype}
  139. super().__init__(
  140. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  141. )
  142. def forward(self, input: Tensor) -> Tensor:
  143. self._check_input_dim(input)
  144. # exponential_average_factor is set to self.momentum
  145. # (when it is available) only so that it gets updated
  146. # in ONNX graph when this node is exported to ONNX.
  147. if self.momentum is None:
  148. exponential_average_factor = 0.0
  149. else:
  150. exponential_average_factor = self.momentum
  151. if self.training and self.track_running_stats:
  152. # TODO: if statement only here to tell the jit to skip emitting this when it is None
  153. if self.num_batches_tracked is not None: # type: ignore[has-type]
  154. self.num_batches_tracked.add_(1) # type: ignore[has-type]
  155. if self.momentum is None: # use cumulative moving average
  156. exponential_average_factor = 1.0 / float(self.num_batches_tracked)
  157. else: # use exponential moving average
  158. exponential_average_factor = self.momentum
  159. r"""
  160. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  161. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  162. """
  163. if self.training:
  164. bn_training = True
  165. else:
  166. bn_training = (self.running_mean is None) and (self.running_var is None)
  167. r"""
  168. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  169. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  170. used for normalization (i.e. in eval mode when buffers are not None).
  171. """
  172. return F.batch_norm(
  173. input,
  174. # If buffers are not to be tracked, ensure that they won't be updated
  175. (
  176. self.running_mean
  177. if not self.training or self.track_running_stats
  178. else None
  179. ),
  180. self.running_var if not self.training or self.track_running_stats else None,
  181. self.weight,
  182. self.bias,
  183. bn_training,
  184. exponential_average_factor,
  185. self.eps,
  186. )
  187. class _LazyNormBase(LazyModuleMixin, _NormBase):
  188. weight: UninitializedParameter # type: ignore[assignment]
  189. bias: UninitializedParameter # type: ignore[assignment]
  190. def __init__(
  191. self,
  192. eps=1e-5,
  193. momentum=0.1,
  194. affine=True,
  195. track_running_stats=True,
  196. device=None,
  197. dtype=None,
  198. ) -> None:
  199. factory_kwargs = {"device": device, "dtype": dtype}
  200. # pyrefly: ignore [bad-argument-type]
  201. super().__init__(
  202. # affine and track_running_stats are hardcoded to False to
  203. # avoid creating tensors that will soon be overwritten.
  204. 0,
  205. eps,
  206. momentum,
  207. False,
  208. False,
  209. **factory_kwargs,
  210. )
  211. self.affine = affine
  212. self.track_running_stats = track_running_stats
  213. if self.affine:
  214. # pyrefly: ignore [unexpected-keyword]
  215. self.weight = UninitializedParameter(**factory_kwargs)
  216. # pyrefly: ignore [unexpected-keyword]
  217. self.bias = UninitializedParameter(**factory_kwargs)
  218. if self.track_running_stats:
  219. # pyrefly: ignore [unexpected-keyword]
  220. self.running_mean = UninitializedBuffer(**factory_kwargs)
  221. # pyrefly: ignore [unexpected-keyword]
  222. self.running_var = UninitializedBuffer(**factory_kwargs)
  223. self.num_batches_tracked = torch.tensor(
  224. 0,
  225. dtype=torch.long,
  226. # pyrefly: ignore [bad-argument-type]
  227. **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
  228. )
  229. def reset_parameters(self) -> None:
  230. # pyrefly: ignore [bad-argument-type]
  231. if not self.has_uninitialized_params() and self.num_features != 0:
  232. super().reset_parameters()
  233. def initialize_parameters(self, input) -> None: # type: ignore[override]
  234. # pyrefly: ignore [bad-argument-type]
  235. if self.has_uninitialized_params():
  236. self.num_features = input.shape[1]
  237. if self.affine:
  238. if not isinstance(self.weight, UninitializedParameter):
  239. raise AssertionError(
  240. "self.weight must be an UninitializedParameter"
  241. )
  242. if not isinstance(self.bias, UninitializedParameter):
  243. raise AssertionError("self.bias must be an UninitializedParameter")
  244. self.weight.materialize((self.num_features,))
  245. self.bias.materialize((self.num_features,))
  246. if self.track_running_stats:
  247. self.running_mean.materialize( # type:ignore[union-attr]
  248. (self.num_features,)
  249. )
  250. self.running_var.materialize( # type:ignore[union-attr]
  251. (self.num_features,)
  252. )
  253. self.reset_parameters()
  254. class BatchNorm1d(_BatchNorm):
  255. r"""Applies Batch Normalization over a 2D or 3D input.
  256. Method described in the paper
  257. `Batch Normalization: Accelerating Deep Network Training by Reducing
  258. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  259. .. math::
  260. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  261. The mean and standard-deviation are calculated per-dimension over
  262. the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  263. of size `C` (where `C` is the number of features or channels of the input). By default, the
  264. elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
  265. At train time in the forward pass, the variance is calculated via the biased estimator,
  266. equivalent to ``torch.var(input, correction=0)``. However, the value stored in the
  267. moving average of the variance is calculated via the unbiased estimator, equivalent to
  268. ``torch.var(input, correction=1)``.
  269. Also by default, during training this layer keeps running estimates of its
  270. computed mean and variance, which are then used for normalization during
  271. evaluation. The running estimates are kept with a default :attr:`momentum`
  272. of 0.1.
  273. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  274. keep running estimates, and batch statistics are instead used during
  275. evaluation time as well.
  276. .. note::
  277. This :attr:`momentum` argument is different from one used in optimizer
  278. classes and the conventional notion of momentum. Mathematically, the
  279. update rule for running statistics here is
  280. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  281. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  282. new observed value.
  283. Because the Batch Normalization is done over the `C` dimension, computing statistics
  284. on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
  285. Args:
  286. num_features: number of features or channels :math:`C` of the input
  287. eps: a value added to the denominator for numerical stability.
  288. Default: 1e-5
  289. momentum: the value used for the running_mean and running_var
  290. computation. Can be set to ``None`` for cumulative moving average
  291. (i.e. simple average). Default: 0.1
  292. affine: a boolean value that when set to ``True``, this module has
  293. learnable affine parameters. Default: ``True``
  294. track_running_stats: a boolean value that when set to ``True``, this
  295. module tracks the running mean and variance, and when set to ``False``,
  296. this module does not track such statistics, and initializes statistics
  297. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  298. When these buffers are ``None``, this module always uses batch statistics.
  299. in both training and eval modes. Default: ``True``
  300. Shape:
  301. - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
  302. :math:`C` is the number of features or channels, and :math:`L` is the sequence length
  303. - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
  304. Examples::
  305. >>> # With Learnable Parameters
  306. >>> m = nn.BatchNorm1d(100)
  307. >>> # Without Learnable Parameters
  308. >>> m = nn.BatchNorm1d(100, affine=False)
  309. >>> input = torch.randn(20, 100)
  310. >>> output = m(input)
  311. """
  312. def _check_input_dim(self, input) -> None:
  313. if input.dim() != 2 and input.dim() != 3:
  314. raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  315. class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
  316. r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
  317. Lazy initialization based on the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
  318. from the ``input.size(1)``.
  319. The attributes that will be lazily initialized are `weight`, `bias`,
  320. `running_mean` and `running_var`.
  321. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  322. on lazy modules and their limitations.
  323. Args:
  324. eps: a value added to the denominator for numerical stability.
  325. Default: 1e-5
  326. momentum: the value used for the running_mean and running_var
  327. computation. Can be set to ``None`` for cumulative moving average
  328. (i.e. simple average). Default: 0.1
  329. affine: a boolean value that when set to ``True``, this module has
  330. learnable affine parameters. Default: ``True``
  331. track_running_stats: a boolean value that when set to ``True``, this
  332. module tracks the running mean and variance, and when set to ``False``,
  333. this module does not track such statistics, and initializes statistics
  334. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  335. When these buffers are ``None``, this module always uses batch statistics.
  336. in both training and eval modes. Default: ``True``
  337. """
  338. cls_to_become = BatchNorm1d # type: ignore[assignment]
  339. def _check_input_dim(self, input) -> None:
  340. if input.dim() != 2 and input.dim() != 3:
  341. raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  342. class BatchNorm2d(_BatchNorm):
  343. r"""Applies Batch Normalization over a 4D input.
  344. 4D is a mini-batch of 2D inputs
  345. with additional channel dimension. Method described in the paper
  346. `Batch Normalization: Accelerating Deep Network Training by Reducing
  347. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  348. .. math::
  349. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  350. The mean and standard-deviation are calculated per-dimension over
  351. the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  352. of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
  353. to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
  354. standard-deviation is calculated via the biased estimator, equivalent to
  355. ``torch.var(input, correction=0)``. However, the value stored in the moving average of the
  356. standard-deviation is calculated via the unbiased estimator, equivalent to
  357. ``torch.var(input, correction=1)``.
  358. Also by default, during training this layer keeps running estimates of its
  359. computed mean and variance, which are then used for normalization during
  360. evaluation. The running estimates are kept with a default :attr:`momentum`
  361. of 0.1.
  362. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  363. keep running estimates, and batch statistics are instead used during
  364. evaluation time as well.
  365. .. note::
  366. This :attr:`momentum` argument is different from one used in optimizer
  367. classes and the conventional notion of momentum. Mathematically, the
  368. update rule for running statistics here is
  369. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  370. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  371. new observed value.
  372. Because the Batch Normalization is done over the `C` dimension, computing statistics
  373. on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
  374. Args:
  375. num_features: :math:`C` from an expected input of size
  376. :math:`(N, C, H, W)`
  377. eps: a value added to the denominator for numerical stability.
  378. Default: 1e-5
  379. momentum: the value used for the running_mean and running_var
  380. computation. Can be set to ``None`` for cumulative moving average
  381. (i.e. simple average). Default: 0.1
  382. affine: a boolean value that when set to ``True``, this module has
  383. learnable affine parameters. Default: ``True``
  384. track_running_stats: a boolean value that when set to ``True``, this
  385. module tracks the running mean and variance, and when set to ``False``,
  386. this module does not track such statistics, and initializes statistics
  387. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  388. When these buffers are ``None``, this module always uses batch statistics.
  389. in both training and eval modes. Default: ``True``
  390. Shape:
  391. - Input: :math:`(N, C, H, W)`
  392. - Output: :math:`(N, C, H, W)` (same shape as input)
  393. Examples::
  394. >>> # With Learnable Parameters
  395. >>> m = nn.BatchNorm2d(100)
  396. >>> # Without Learnable Parameters
  397. >>> m = nn.BatchNorm2d(100, affine=False)
  398. >>> input = torch.randn(20, 100, 35, 45)
  399. >>> output = m(input)
  400. """
  401. def _check_input_dim(self, input) -> None:
  402. if input.dim() != 4:
  403. raise ValueError(f"expected 4D input (got {input.dim()}D input)")
  404. class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
  405. r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
  406. Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
  407. from the ``input.size(1)``.
  408. The attributes that will be lazily initialized are `weight`, `bias`,
  409. `running_mean` and `running_var`.
  410. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  411. on lazy modules and their limitations.
  412. Args:
  413. eps: a value added to the denominator for numerical stability.
  414. Default: 1e-5
  415. momentum: the value used for the running_mean and running_var
  416. computation. Can be set to ``None`` for cumulative moving average
  417. (i.e. simple average). Default: 0.1
  418. affine: a boolean value that when set to ``True``, this module has
  419. learnable affine parameters. Default: ``True``
  420. track_running_stats: a boolean value that when set to ``True``, this
  421. module tracks the running mean and variance, and when set to ``False``,
  422. this module does not track such statistics, and initializes statistics
  423. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  424. When these buffers are ``None``, this module always uses batch statistics.
  425. in both training and eval modes. Default: ``True``
  426. """
  427. cls_to_become = BatchNorm2d # type: ignore[assignment]
  428. def _check_input_dim(self, input) -> None:
  429. if input.dim() != 4:
  430. raise ValueError(f"expected 4D input (got {input.dim()}D input)")
  431. class BatchNorm3d(_BatchNorm):
  432. r"""Applies Batch Normalization over a 5D input.
  433. 5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
  434. `Batch Normalization: Accelerating Deep Network Training by Reducing
  435. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  436. .. math::
  437. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  438. The mean and standard-deviation are calculated per-dimension over
  439. the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  440. of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
  441. to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
  442. standard-deviation is calculated via the biased estimator, equivalent to
  443. ``torch.var(input, correction=0)``. However, the value stored in the moving average of the
  444. standard-deviation is calculated via the unbiased estimator, equivalent to
  445. ``torch.var(input, correction=1)``.
  446. Also by default, during training this layer keeps running estimates of its
  447. computed mean and variance, which are then used for normalization during
  448. evaluation. The running estimates are kept with a default :attr:`momentum`
  449. of 0.1.
  450. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  451. keep running estimates, and batch statistics are instead used during
  452. evaluation time as well.
  453. .. note::
  454. This :attr:`momentum` argument is different from one used in optimizer
  455. classes and the conventional notion of momentum. Mathematically, the
  456. update rule for running statistics here is
  457. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  458. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  459. new observed value.
  460. Because the Batch Normalization is done over the `C` dimension, computing statistics
  461. on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
  462. or Spatio-temporal Batch Normalization.
  463. Args:
  464. num_features: :math:`C` from an expected input of size
  465. :math:`(N, C, D, H, W)`
  466. eps: a value added to the denominator for numerical stability.
  467. Default: 1e-5
  468. momentum: the value used for the running_mean and running_var
  469. computation. Can be set to ``None`` for cumulative moving average
  470. (i.e. simple average). Default: 0.1
  471. affine: a boolean value that when set to ``True``, this module has
  472. learnable affine parameters. Default: ``True``
  473. track_running_stats: a boolean value that when set to ``True``, this
  474. module tracks the running mean and variance, and when set to ``False``,
  475. this module does not track such statistics, and initializes statistics
  476. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  477. When these buffers are ``None``, this module always uses batch statistics.
  478. in both training and eval modes. Default: ``True``
  479. Shape:
  480. - Input: :math:`(N, C, D, H, W)`
  481. - Output: :math:`(N, C, D, H, W)` (same shape as input)
  482. Examples::
  483. >>> # With Learnable Parameters
  484. >>> m = nn.BatchNorm3d(100)
  485. >>> # Without Learnable Parameters
  486. >>> m = nn.BatchNorm3d(100, affine=False)
  487. >>> input = torch.randn(20, 100, 35, 45, 10)
  488. >>> output = m(input)
  489. """
  490. def _check_input_dim(self, input) -> None:
  491. if input.dim() != 5:
  492. raise ValueError(f"expected 5D input (got {input.dim()}D input)")
  493. class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
  494. r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.
  495. Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
  496. from the ``input.size(1)``.
  497. The attributes that will be lazily initialized are `weight`, `bias`,
  498. `running_mean` and `running_var`.
  499. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  500. on lazy modules and their limitations.
  501. Args:
  502. eps: a value added to the denominator for numerical stability.
  503. Default: 1e-5
  504. momentum: the value used for the running_mean and running_var
  505. computation. Can be set to ``None`` for cumulative moving average
  506. (i.e. simple average). Default: 0.1
  507. affine: a boolean value that when set to ``True``, this module has
  508. learnable affine parameters. Default: ``True``
  509. track_running_stats: a boolean value that when set to ``True``, this
  510. module tracks the running mean and variance, and when set to ``False``,
  511. this module does not track such statistics, and initializes statistics
  512. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  513. When these buffers are ``None``, this module always uses batch statistics.
  514. in both training and eval modes. Default: ``True``
  515. """
  516. cls_to_become = BatchNorm3d # type: ignore[assignment]
  517. def _check_input_dim(self, input) -> None:
  518. if input.dim() != 5:
  519. raise ValueError(f"expected 5D input (got {input.dim()}D input)")
  520. class SyncBatchNorm(_BatchNorm):
  521. r"""Applies Batch Normalization over a N-Dimensional input.
  522. The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension as described in the paper
  523. `Batch Normalization: Accelerating Deep Network Training by Reducing
  524. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  525. .. math::
  526. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  527. The mean and standard-deviation are calculated per-dimension over all
  528. mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
  529. are learnable parameter vectors of size `C` (where `C` is the input size).
  530. By default, the elements of :math:`\gamma` are set to 1
  531. and the elements of :math:`\beta` are set to 0.
  532. The standard-deviation is calculated via the biased estimator, equivalent to
  533. `torch.var(input, correction=0)`.
  534. Also by default, during training this layer keeps running estimates of its
  535. computed mean and variance, which are then used for normalization during
  536. evaluation. The running estimates are kept with a default :attr:`momentum`
  537. of 0.1.
  538. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  539. keep running estimates, and batch statistics are instead used during
  540. evaluation time as well.
  541. .. note::
  542. This :attr:`momentum` argument is different from one used in optimizer
  543. classes and the conventional notion of momentum. Mathematically, the
  544. update rule for running statistics here is
  545. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  546. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  547. new observed value.
  548. Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
  549. statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
  550. Normalization or Spatio-temporal Batch Normalization.
  551. Currently :class:`SyncBatchNorm` only supports
  552. :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
  553. :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
  554. :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
  555. Network with DDP.
  556. Args:
  557. num_features: :math:`C` from an expected input of size
  558. :math:`(N, C, +)`
  559. eps: a value added to the denominator for numerical stability.
  560. Default: ``1e-5``
  561. momentum: the value used for the running_mean and running_var
  562. computation. Can be set to ``None`` for cumulative moving average
  563. (i.e. simple average). Default: 0.1
  564. affine: a boolean value that when set to ``True``, this module has
  565. learnable affine parameters. Default: ``True``
  566. track_running_stats: a boolean value that when set to ``True``, this
  567. module tracks the running mean and variance, and when set to ``False``,
  568. this module does not track such statistics, and initializes statistics
  569. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  570. When these buffers are ``None``, this module always uses batch statistics
  571. in both training and eval modes. Default: ``True``
  572. process_group: synchronization of stats happen within each process group
  573. individually. Default behavior is synchronization across the whole
  574. world
    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>     sync_bn_network,
        >>>     device_ids=[args.local_rank],
        >>>     output_device=args.local_rank)
    """

  607. def __init__(
  608. self,
  609. num_features: int,
  610. eps: float = 1e-5,
  611. momentum: float | None = 0.1,
  612. affine: bool = True,
  613. track_running_stats: bool = True,
  614. process_group: Any | None = None,
  615. device=None,
  616. dtype=None,
  617. ) -> None:
  618. factory_kwargs = {"device": device, "dtype": dtype}
  619. super().__init__(
  620. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  621. )
  622. self.process_group = process_group
  623. def _check_input_dim(self, input) -> None:
  624. if input.dim() < 2:
  625. raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")
  626. def _check_non_zero_input_channels(self, input) -> None:
  627. if input.size(1) == 0:
  628. raise ValueError(
  629. "SyncBatchNorm number of input channels should be non-zero"
  630. )
  631. def forward(self, input: Tensor) -> Tensor:
  632. """
  633. Runs the forward pass.
  634. """
  635. self._check_input_dim(input)
  636. self._check_non_zero_input_channels(input)
  637. # exponential_average_factor is set to self.momentum
  638. # (when it is available) only so that it gets updated
  639. # in ONNX graph when this node is exported to ONNX.
  640. if self.momentum is None:
  641. exponential_average_factor = 0.0
  642. else:
  643. exponential_average_factor = self.momentum
  644. if self.training and self.track_running_stats:
  645. if self.num_batches_tracked is None:
  646. raise AssertionError("num_batches_tracked must not be None")
  647. self.num_batches_tracked.add_(1)
  648. if self.momentum is None: # use cumulative moving average
  649. exponential_average_factor = 1.0 / self.num_batches_tracked.item()
  650. else: # use exponential moving average
  651. exponential_average_factor = self.momentum
  652. r"""
  653. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  654. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  655. """
  656. if self.training:
  657. bn_training = True
  658. else:
  659. bn_training = (self.running_mean is None) and (self.running_var is None)
  660. r"""
  661. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  662. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  663. used for normalization (i.e. in eval mode when buffers are not None).
  664. """
  665. # If buffers are not to be tracked, ensure that they won't be updated
  666. running_mean = (
  667. self.running_mean if not self.training or self.track_running_stats else None
  668. )
  669. running_var = (
  670. self.running_var if not self.training or self.track_running_stats else None
  671. )
  672. # Don't sync batchnorm stats in inference mode (model.eval()).
  673. need_sync = (
  674. bn_training
  675. and self.training
  676. and torch.distributed.is_available()
  677. and torch.distributed.is_initialized()
  678. )
  679. if need_sync:
  680. # currently only GPU/PrivateUse1 input is supported
  681. if input.device.type not in [
  682. "cuda",
  683. "hpu",
  684. "xpu",
  685. torch._C._get_privateuse1_backend_name(),
  686. ]:
  687. raise ValueError(
  688. "SyncBatchNorm expected input tensor to be on GPU or XPU or "
  689. f"{torch._C._get_privateuse1_backend_name()}"
  690. )
  691. process_group = torch.distributed.group.WORLD
  692. if self.process_group:
  693. process_group = self.process_group
  694. world_size = torch.distributed.get_world_size(process_group)
  695. need_sync = world_size > 1
  696. # fallback to framework BN when synchronization is not necessary
  697. if not need_sync:
  698. return F.batch_norm(
  699. input,
  700. running_mean,
  701. running_var,
  702. self.weight,
  703. self.bias,
  704. bn_training,
  705. exponential_average_factor,
  706. self.eps,
  707. )
  708. else:
  709. if not bn_training:
  710. raise AssertionError("bn_training must be True")
  711. return sync_batch_norm.apply(
  712. input,
  713. self.weight,
  714. self.bias,
  715. running_mean,
  716. running_var,
  717. self.eps,
  718. exponential_average_factor,
  719. process_group, # type: ignore[possibly-undefined]
  720. world_size, # type: ignore[possibly-undefined]
  721. )
  722. @classmethod
  723. def convert_sync_batchnorm(cls, module, process_group=None):
  724. r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.
  725. Args:
  726. module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
  727. process_group (optional): process group to scope synchronization,
  728. default is the whole world
  729. Returns:
  730. The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
  731. layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
  732. a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
  733. instead.
  734. Example::
  735. >>> # Network with nn.BatchNorm layer
  736. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  737. >>> module = torch.nn.Sequential(
  738. >>> torch.nn.Linear(20, 100),
  739. >>> torch.nn.BatchNorm1d(100),
  740. >>> ).cuda()
  741. >>> # creating process group (optional)
  742. >>> # ranks is a list of int identifying rank ids.
  743. >>> ranks = list(range(8))
  744. >>> r1, r2 = ranks[:4], ranks[4:]
  745. >>> # Note: every rank calls into new_group for every
  746. >>> # process group created, even if that rank is not
  747. >>> # part of the group.
  748. >>> # xdoctest: +SKIP("distributed")
  749. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
  750. >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
  751. >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
  752. """
  753. module_output = module
  754. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  755. module_output = torch.nn.SyncBatchNorm(
  756. module.num_features,
  757. module.eps,
  758. module.momentum,
  759. module.affine,
  760. module.track_running_stats,
  761. process_group,
  762. )
  763. if module.affine:
  764. with torch.no_grad():
  765. module_output.weight = module.weight
  766. module_output.bias = module.bias
  767. module_output.running_mean = module.running_mean
  768. module_output.running_var = module.running_var
  769. module_output.num_batches_tracked = module.num_batches_tracked
  770. module_output.training = module.training
  771. if hasattr(module, "qconfig"):
  772. module_output.qconfig = module.qconfig
  773. for name, child in module.named_children():
  774. module_output.add_module(
  775. name, cls.convert_sync_batchnorm(child, process_group)
  776. )
  777. del module
  778. return module_output