instancenorm.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. import torch.nn.functional as F
  4. from torch import Tensor
  5. from .batchnorm import _LazyNormBase, _NormBase
  6. __all__ = [
  7. "InstanceNorm1d",
  8. "InstanceNorm2d",
  9. "InstanceNorm3d",
  10. "LazyInstanceNorm1d",
  11. "LazyInstanceNorm2d",
  12. "LazyInstanceNorm3d",
  13. ]
  14. class _InstanceNorm(_NormBase):
  15. def __init__(
  16. self,
  17. num_features: int,
  18. eps: float = 1e-5,
  19. momentum: float = 0.1,
  20. affine: bool = False,
  21. track_running_stats: bool = False,
  22. device=None,
  23. dtype=None,
  24. ) -> None:
  25. factory_kwargs = {"device": device, "dtype": dtype}
  26. super().__init__(
  27. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  28. )
  29. def _check_input_dim(self, input):
  30. raise NotImplementedError
  31. def _get_no_batch_dim(self):
  32. raise NotImplementedError
  33. def _handle_no_batch_input(self, input):
  34. return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0)
  35. def _apply_instance_norm(self, input):
  36. return F.instance_norm(
  37. input,
  38. self.running_mean,
  39. self.running_var,
  40. self.weight,
  41. self.bias,
  42. self.training or not self.track_running_stats,
  43. self.momentum if self.momentum is not None else 0.0,
  44. self.eps,
  45. )
  46. def _load_from_state_dict(
  47. self,
  48. state_dict,
  49. prefix,
  50. local_metadata,
  51. strict,
  52. missing_keys,
  53. unexpected_keys,
  54. error_msgs,
  55. ) -> None:
  56. version = local_metadata.get("version", None)
  57. # at version 1: removed running_mean and running_var when
  58. # track_running_stats=False (default)
  59. if version is None and not self.track_running_stats:
  60. running_stats_keys = []
  61. for name in ("running_mean", "running_var"):
  62. key = prefix + name
  63. if key in state_dict:
  64. running_stats_keys.append(key)
  65. if len(running_stats_keys) > 0:
  66. error_msgs.append(
  67. "Unexpected running stats buffer(s) {names} for {klass} "
  68. "with track_running_stats=False. If state_dict is a "
  69. "checkpoint saved before 0.4.0, this may be expected "
  70. "because {klass} does not track running stats by default "
  71. "since 0.4.0. Please remove these keys from state_dict. If "
  72. "the running stats are actually needed, instead set "
  73. "track_running_stats=True in {klass} to enable them. See "
  74. "the documentation of {klass} for details.".format(
  75. names=" and ".join(f'"{k}"' for k in running_stats_keys),
  76. klass=self.__class__.__name__,
  77. )
  78. )
  79. for key in running_stats_keys:
  80. state_dict.pop(key)
  81. super()._load_from_state_dict(
  82. state_dict,
  83. prefix,
  84. local_metadata,
  85. strict,
  86. missing_keys,
  87. unexpected_keys,
  88. error_msgs,
  89. )
  90. def forward(self, input: Tensor) -> Tensor:
  91. self._check_input_dim(input)
  92. feature_dim = input.dim() - self._get_no_batch_dim()
  93. if input.size(feature_dim) != self.num_features:
  94. if self.affine:
  95. raise ValueError(
  96. f"expected input's size at dim={feature_dim} to match num_features"
  97. f" ({self.num_features}), but got: {input.size(feature_dim)}."
  98. )
  99. else:
  100. warnings.warn(
  101. f"input's size at dim={feature_dim} does not match num_features. "
  102. "You can silence this warning by not passing in num_features, "
  103. "which is not used because affine=False",
  104. stacklevel=2,
  105. )
  106. if input.dim() == self._get_no_batch_dim():
  107. return self._handle_no_batch_input(input)
  108. return self._apply_instance_norm(input)
  109. class InstanceNorm1d(_InstanceNorm):
  110. r"""Applies Instance Normalization.
  111. This operation applies Instance Normalization
  112. over a 2D (unbatched) or 3D (batched) input as described in the paper
  113. `Instance Normalization: The Missing Ingredient for Fast Stylization
  114. <https://arxiv.org/abs/1607.08022>`__.
  115. .. math::
  116. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  117. The mean and standard-deviation are calculated per-dimension separately
  118. for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  119. of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``.
  120. The variance is calculated via the biased estimator, equivalent to
  121. `torch.var(input, correction=0)`.
  122. By default, this layer uses instance statistics computed from input data in
  123. both training and evaluation modes.
  124. If :attr:`track_running_stats` is set to ``True``, during training this
  125. layer keeps running estimates of its computed mean and variance, which are
  126. then used for normalization during evaluation. The running estimates are
  127. kept with a default :attr:`momentum` of 0.1.
  128. .. note::
  129. This :attr:`momentum` argument is different from one used in optimizer
  130. classes and the conventional notion of momentum. Mathematically, the
  131. update rule for running statistics here is
  132. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  133. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  134. new observed value.
  135. .. note::
  136. :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
  137. have some subtle differences. :class:`InstanceNorm1d` is applied
  138. on each channel of channeled data like multidimensional time series, but
  139. :class:`LayerNorm` is usually applied on entire sample and often in NLP
  140. tasks. Additionally, :class:`LayerNorm` applies elementwise affine
  141. transform, while :class:`InstanceNorm1d` usually don't apply affine
  142. transform.
  143. Args:
  144. num_features: number of features or channels :math:`C` of the input
  145. eps: a value added to the denominator for numerical stability. Default: 1e-5
  146. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  147. affine: a boolean value that when set to ``True``, this module has
  148. learnable affine parameters, initialized the same way as done for batch normalization.
  149. Default: ``False``.
  150. track_running_stats: a boolean value that when set to ``True``, this
  151. module tracks the running mean and variance, and when set to ``False``,
  152. this module does not track such statistics and always uses batch
  153. statistics in both training and eval modes. Default: ``False``
  154. Shape:
  155. - Input: :math:`(N, C, L)` or :math:`(C, L)`
  156. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
  157. Examples::
  158. >>> # Without Learnable Parameters
  159. >>> m = nn.InstanceNorm1d(100)
  160. >>> # With Learnable Parameters
  161. >>> m = nn.InstanceNorm1d(100, affine=True)
  162. >>> input = torch.randn(20, 100, 40)
  163. >>> output = m(input)
  164. """
  165. def _get_no_batch_dim(self) -> int:
  166. return 2
  167. def _check_input_dim(self, input) -> None:
  168. if input.dim() not in (2, 3):
  169. raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  170. class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):
  171. r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument.
  172. The ``num_features`` argument of the :class:`InstanceNorm1d` is inferred from the ``input.size(1)``.
  173. The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`.
  174. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  175. on lazy modules and their limitations.
  176. Args:
  177. num_features: :math:`C` from an expected input of size
  178. :math:`(N, C, L)` or :math:`(C, L)`
  179. eps: a value added to the denominator for numerical stability. Default: 1e-5
  180. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  181. affine: a boolean value that when set to ``True``, this module has
  182. learnable affine parameters, initialized the same way as done for batch normalization.
  183. Default: ``False``.
  184. track_running_stats: a boolean value that when set to ``True``, this
  185. module tracks the running mean and variance, and when set to ``False``,
  186. this module does not track such statistics and always uses batch
  187. statistics in both training and eval modes. Default: ``False``
  188. Shape:
  189. - Input: :math:`(N, C, L)` or :math:`(C, L)`
  190. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
  191. """
  192. cls_to_become = InstanceNorm1d # type: ignore[assignment]
  193. def _get_no_batch_dim(self) -> int:
  194. return 2
  195. def _check_input_dim(self, input) -> None:
  196. if input.dim() not in (2, 3):
  197. raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  198. class InstanceNorm2d(_InstanceNorm):
  199. r"""Applies Instance Normalization.
  200. This operation applies Instance Normalization
  201. over a 4D input (a mini-batch of 2D inputs
  202. with additional channel dimension) as described in the paper
  203. `Instance Normalization: The Missing Ingredient for Fast Stylization
  204. <https://arxiv.org/abs/1607.08022>`__.
  205. .. math::
  206. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  207. The mean and standard-deviation are calculated per-dimension separately
  208. for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  209. of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
  210. The standard-deviation is calculated via the biased estimator, equivalent to
  211. `torch.var(input, correction=0)`.
  212. By default, this layer uses instance statistics computed from input data in
  213. both training and evaluation modes.
  214. If :attr:`track_running_stats` is set to ``True``, during training this
  215. layer keeps running estimates of its computed mean and variance, which are
  216. then used for normalization during evaluation. The running estimates are
  217. kept with a default :attr:`momentum` of 0.1.
  218. .. note::
  219. This :attr:`momentum` argument is different from one used in optimizer
  220. classes and the conventional notion of momentum. Mathematically, the
  221. update rule for running statistics here is
  222. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  223. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  224. new observed value.
  225. .. note::
  226. :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
  227. have some subtle differences. :class:`InstanceNorm2d` is applied
  228. on each channel of channeled data like RGB images, but
  229. :class:`LayerNorm` is usually applied on entire sample and often in NLP
  230. tasks. Additionally, :class:`LayerNorm` applies elementwise affine
  231. transform, while :class:`InstanceNorm2d` usually don't apply affine
  232. transform.
  233. Args:
  234. num_features: :math:`C` from an expected input of size
  235. :math:`(N, C, H, W)` or :math:`(C, H, W)`
  236. eps: a value added to the denominator for numerical stability. Default: 1e-5
  237. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  238. affine: a boolean value that when set to ``True``, this module has
  239. learnable affine parameters, initialized the same way as done for batch normalization.
  240. Default: ``False``.
  241. track_running_stats: a boolean value that when set to ``True``, this
  242. module tracks the running mean and variance, and when set to ``False``,
  243. this module does not track such statistics and always uses batch
  244. statistics in both training and eval modes. Default: ``False``
  245. Shape:
  246. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
  247. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  248. Examples::
  249. >>> # Without Learnable Parameters
  250. >>> m = nn.InstanceNorm2d(100)
  251. >>> # With Learnable Parameters
  252. >>> m = nn.InstanceNorm2d(100, affine=True)
  253. >>> input = torch.randn(20, 100, 35, 45)
  254. >>> output = m(input)
  255. """
  256. def _get_no_batch_dim(self) -> int:
  257. return 3
  258. def _check_input_dim(self, input) -> None:
  259. if input.dim() not in (3, 4):
  260. raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")
  261. class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):
  262. r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument.
  263. The ``num_features`` argument of the :class:`InstanceNorm2d` is inferred from the ``input.size(1)``.
  264. The attributes that will be lazily initialized are `weight`, `bias`,
  265. `running_mean` and `running_var`.
  266. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  267. on lazy modules and their limitations.
  268. Args:
  269. num_features: :math:`C` from an expected input of size
  270. :math:`(N, C, H, W)` or :math:`(C, H, W)`
  271. eps: a value added to the denominator for numerical stability. Default: 1e-5
  272. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  273. affine: a boolean value that when set to ``True``, this module has
  274. learnable affine parameters, initialized the same way as done for batch normalization.
  275. Default: ``False``.
  276. track_running_stats: a boolean value that when set to ``True``, this
  277. module tracks the running mean and variance, and when set to ``False``,
  278. this module does not track such statistics and always uses batch
  279. statistics in both training and eval modes. Default: ``False``
  280. Shape:
  281. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
  282. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  283. """
  284. cls_to_become = InstanceNorm2d # type: ignore[assignment]
  285. def _get_no_batch_dim(self) -> int:
  286. return 3
  287. def _check_input_dim(self, input) -> None:
  288. if input.dim() not in (3, 4):
  289. raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")
  290. class InstanceNorm3d(_InstanceNorm):
  291. r"""Applies Instance Normalization.
  292. This operation applies Instance Normalization
  293. over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper
  294. `Instance Normalization: The Missing Ingredient for Fast Stylization
  295. <https://arxiv.org/abs/1607.08022>`__.
  296. .. math::
  297. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  298. The mean and standard-deviation are calculated per-dimension separately
  299. for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  300. of size C (where C is the input size) if :attr:`affine` is ``True``.
  301. The standard-deviation is calculated via the biased estimator, equivalent to
  302. `torch.var(input, correction=0)`.
  303. By default, this layer uses instance statistics computed from input data in
  304. both training and evaluation modes.
  305. If :attr:`track_running_stats` is set to ``True``, during training this
  306. layer keeps running estimates of its computed mean and variance, which are
  307. then used for normalization during evaluation. The running estimates are
  308. kept with a default :attr:`momentum` of 0.1.
  309. .. note::
  310. This :attr:`momentum` argument is different from one used in optimizer
  311. classes and the conventional notion of momentum. Mathematically, the
  312. update rule for running statistics here is
  313. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  314. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  315. new observed value.
  316. .. note::
  317. :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
  318. have some subtle differences. :class:`InstanceNorm3d` is applied
  319. on each channel of channeled data like 3D models with RGB color, but
  320. :class:`LayerNorm` is usually applied on entire sample and often in NLP
  321. tasks. Additionally, :class:`LayerNorm` applies elementwise affine
  322. transform, while :class:`InstanceNorm3d` usually don't apply affine
  323. transform.
  324. Args:
  325. num_features: :math:`C` from an expected input of size
  326. :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  327. eps: a value added to the denominator for numerical stability. Default: 1e-5
  328. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  329. affine: a boolean value that when set to ``True``, this module has
  330. learnable affine parameters, initialized the same way as done for batch normalization.
  331. Default: ``False``.
  332. track_running_stats: a boolean value that when set to ``True``, this
  333. module tracks the running mean and variance, and when set to ``False``,
  334. this module does not track such statistics and always uses batch
  335. statistics in both training and eval modes. Default: ``False``
  336. Shape:
  337. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  338. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
  339. Examples::
  340. >>> # Without Learnable Parameters
  341. >>> m = nn.InstanceNorm3d(100)
  342. >>> # With Learnable Parameters
  343. >>> m = nn.InstanceNorm3d(100, affine=True)
  344. >>> input = torch.randn(20, 100, 35, 45, 10)
  345. >>> output = m(input)
  346. """
  347. def _get_no_batch_dim(self) -> int:
  348. return 4
  349. def _check_input_dim(self, input) -> None:
  350. if input.dim() not in (4, 5):
  351. raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")
  352. class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):
  353. r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument.
  354. The ``num_features`` argument of the :class:`InstanceNorm3d` is inferred from the ``input.size(1)``.
  355. The attributes that will be lazily initialized are `weight`, `bias`,
  356. `running_mean` and `running_var`.
  357. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  358. on lazy modules and their limitations.
  359. Args:
  360. num_features: :math:`C` from an expected input of size
  361. :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  362. eps: a value added to the denominator for numerical stability. Default: 1e-5
  363. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  364. affine: a boolean value that when set to ``True``, this module has
  365. learnable affine parameters, initialized the same way as done for batch normalization.
  366. Default: ``False``.
  367. track_running_stats: a boolean value that when set to ``True``, this
  368. module tracks the running mean and variance, and when set to ``False``,
  369. this module does not track such statistics and always uses batch
  370. statistics in both training and eval modes. Default: ``False``
  371. Shape:
  372. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  373. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
  374. """
  375. cls_to_become = InstanceNorm3d # type: ignore[assignment]
  376. def _get_no_batch_dim(self) -> int:
  377. return 4
  378. def _check_input_dim(self, input) -> None:
  379. if input.dim() not in (4, 5):
  380. raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")