normalization.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # mypy: allow-untyped-defs
  2. import numbers
  3. from typing import Union
  4. import torch
  5. from torch import Size, Tensor
  6. from torch.nn import functional as F, init
  7. from torch.nn.parameter import Parameter
  8. from ._functions import CrossMapLRN2d as _cross_map_lrn2d
  9. from .module import Module
# Names exported by ``from <this module> import *``; one entry per normalization
# layer class defined below.
__all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"]
  11. class LocalResponseNorm(Module):
  12. r"""Applies local response normalization over an input signal.
  13. The input signal is composed of several input planes, where channels occupy the second dimension.
  14. Applies normalization across channels.
  15. .. math::
  16. b_{c} = a_{c}\left(k + \frac{\alpha}{n}
  17. \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}
  18. Args:
  19. size: amount of neighbouring channels used for normalization
  20. alpha: multiplicative factor. Default: 0.0001
  21. beta: exponent. Default: 0.75
  22. k: additive factor. Default: 1
  23. Shape:
  24. - Input: :math:`(N, C, *)`
  25. - Output: :math:`(N, C, *)` (same shape as input)
  26. Examples::
  27. >>> lrn = nn.LocalResponseNorm(2)
  28. >>> signal_2d = torch.randn(32, 5, 24, 24)
  29. >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
  30. >>> output_2d = lrn(signal_2d)
  31. >>> output_4d = lrn(signal_4d)
  32. """
  33. __constants__ = ["size", "alpha", "beta", "k"]
  34. size: int
  35. alpha: float
  36. beta: float
  37. k: float
  38. def __init__(
  39. self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0
  40. ) -> None:
  41. super().__init__()
  42. self.size = size
  43. self.alpha = alpha
  44. self.beta = beta
  45. self.k = k
  46. def forward(self, input: Tensor) -> Tensor:
  47. """
  48. Runs the forward pass.
  49. """
  50. return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k)
  51. def extra_repr(self):
  52. """
  53. Return the extra representation of the module.
  54. """
  55. return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__)
  56. class CrossMapLRN2d(Module):
  57. size: int
  58. alpha: float
  59. beta: float
  60. k: float
  61. def __init__(
  62. self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1
  63. ) -> None:
  64. super().__init__()
  65. self.size = size
  66. self.alpha = alpha
  67. self.beta = beta
  68. self.k = k
  69. def forward(self, input: Tensor) -> Tensor:
  70. """
  71. Runs the forward pass.
  72. """
  73. return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k)
  74. def extra_repr(self) -> str:
  75. """
  76. Return the extra representation of the module.
  77. """
  78. return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__)
# Type alias used to annotate the ``normalized_shape`` constructor argument of
# LayerNorm and RMSNorm: a single int, a list of ints, or a ``torch.Size``.
_shape_t = Union[int, list[int], Size]
  80. class LayerNorm(Module):
  81. r"""Applies Layer Normalization over a mini-batch of inputs.
  82. This layer implements the operation as described in
  83. the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
  84. .. math::
  85. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  86. The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
  87. is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
  88. is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
  89. the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
  90. :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
  91. :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
  92. The variance is calculated via the biased estimator, equivalent to
  93. `torch.var(input, correction=0)`.
  94. .. note::
  95. Unlike Batch Normalization and Instance Normalization, which applies
  96. scalar scale and bias for each entire channel/plane with the
  97. :attr:`affine` option, Layer Normalization applies per-element scale and
  98. bias with :attr:`elementwise_affine`.
  99. This layer uses statistics computed from input data in both training and
  100. evaluation modes.
  101. Args:
  102. normalized_shape (int or list or torch.Size): input shape from an expected input
  103. of size
  104. .. math::
  105. [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
  106. \times \ldots \times \text{normalized\_shape}[-1]]
  107. If a single integer is used, it is treated as a singleton list, and this module will
  108. normalize over the last dimension which is expected to be of that specific size.
  109. eps: a value added to the denominator for numerical stability. Default: 1e-5
  110. elementwise_affine: a boolean value that when set to ``True``, this module
  111. has learnable per-element affine parameters initialized to ones (for weights)
  112. and zeros (for biases). Default: ``True``.
  113. bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
  114. :attr:`elementwise_affine` is ``True``). Default: ``True``.
  115. Attributes:
  116. weight: the learnable weights of the module of shape
  117. :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
  118. The values are initialized to 1.
  119. bias: the learnable bias of the module of shape
  120. :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
  121. The values are initialized to 0.
  122. Shape:
  123. - Input: :math:`(N, *)`
  124. - Output: :math:`(N, *)` (same shape as input)
  125. Examples::
  126. >>> # NLP Example
  127. >>> batch, sentence_length, embedding_dim = 20, 5, 10
  128. >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
  129. >>> layer_norm = nn.LayerNorm(embedding_dim)
  130. >>> # Activate module
  131. >>> layer_norm(embedding)
  132. >>>
  133. >>> # Image Example
  134. >>> N, C, H, W = 20, 5, 10, 10
  135. >>> input = torch.randn(N, C, H, W)
  136. >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
  137. >>> # as shown in the image below
  138. >>> layer_norm = nn.LayerNorm([C, H, W])
  139. >>> output = layer_norm(input)
  140. .. image:: ../_static/img/nn/layer_norm.jpg
  141. :scale: 50 %
  142. """
  143. __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
  144. normalized_shape: tuple[int, ...]
  145. eps: float
  146. elementwise_affine: bool
  147. def __init__(
  148. self,
  149. normalized_shape: _shape_t,
  150. eps: float = 1e-5,
  151. elementwise_affine: bool = True,
  152. bias: bool = True,
  153. device=None,
  154. dtype=None,
  155. ) -> None:
  156. factory_kwargs = {"device": device, "dtype": dtype}
  157. super().__init__()
  158. if isinstance(normalized_shape, numbers.Integral):
  159. # mypy error: incompatible types in assignment
  160. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  161. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  162. self.eps = eps
  163. self.elementwise_affine = elementwise_affine
  164. if self.elementwise_affine:
  165. self.weight = Parameter(
  166. torch.empty(self.normalized_shape, **factory_kwargs)
  167. )
  168. if bias:
  169. self.bias = Parameter(
  170. torch.empty(self.normalized_shape, **factory_kwargs)
  171. )
  172. else:
  173. self.register_parameter("bias", None)
  174. else:
  175. self.register_parameter("weight", None)
  176. self.register_parameter("bias", None)
  177. self.reset_parameters()
  178. def reset_parameters(self) -> None:
  179. if self.elementwise_affine:
  180. init.ones_(self.weight)
  181. if self.bias is not None:
  182. init.zeros_(self.bias)
  183. def forward(self, input: Tensor) -> Tensor:
  184. return F.layer_norm(
  185. input, self.normalized_shape, self.weight, self.bias, self.eps
  186. )
  187. def extra_repr(self) -> str:
  188. return (
  189. "{normalized_shape}, eps={eps}, "
  190. "elementwise_affine={elementwise_affine}".format(**self.__dict__)
  191. )
  192. class GroupNorm(Module):
  193. r"""Applies Group Normalization over a mini-batch of inputs.
  194. This layer implements the operation as described in
  195. the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
  196. .. math::
  197. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  198. The input channels are separated into :attr:`num_groups` groups, each containing
  199. ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
  200. :attr:`num_groups`. The mean and standard-deviation are calculated
  201. separately over each group. :math:`\gamma` and :math:`\beta` are learnable
  202. per-channel affine transform parameter vectors of size :attr:`num_channels` if
  203. :attr:`affine` is ``True``.
  204. The variance is calculated via the biased estimator, equivalent to
  205. `torch.var(input, correction=0)`.
  206. This layer uses statistics computed from input data in both training and
  207. evaluation modes.
  208. Args:
  209. num_groups (int): number of groups to separate the channels into
  210. num_channels (int): number of channels expected in input
  211. eps: a value added to the denominator for numerical stability. Default: 1e-5
  212. affine: a boolean value that when set to ``True``, this module
  213. has learnable per-channel affine parameters initialized to ones (for weights)
  214. and zeros (for biases). Default: ``True``.
  215. Shape:
  216. - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
  217. - Output: :math:`(N, C, *)` (same shape as input)
  218. Examples::
  219. >>> input = torch.randn(20, 6, 10, 10)
  220. >>> # Separate 6 channels into 3 groups
  221. >>> m = nn.GroupNorm(3, 6)
  222. >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
  223. >>> m = nn.GroupNorm(6, 6)
  224. >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
  225. >>> m = nn.GroupNorm(1, 6)
  226. >>> # Activating the module
  227. >>> output = m(input)
  228. """
  229. __constants__ = ["num_groups", "num_channels", "eps", "affine"]
  230. num_groups: int
  231. num_channels: int
  232. eps: float
  233. affine: bool
  234. def __init__(
  235. self,
  236. num_groups: int,
  237. num_channels: int,
  238. eps: float = 1e-5,
  239. affine: bool = True,
  240. device=None,
  241. dtype=None,
  242. ) -> None:
  243. factory_kwargs = {"device": device, "dtype": dtype}
  244. super().__init__()
  245. if num_channels % num_groups != 0:
  246. raise ValueError(
  247. f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
  248. )
  249. self.num_groups = num_groups
  250. self.num_channels = num_channels
  251. self.eps = eps
  252. self.affine = affine
  253. if self.affine:
  254. self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
  255. self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
  256. else:
  257. self.register_parameter("weight", None)
  258. self.register_parameter("bias", None)
  259. self.reset_parameters()
  260. def reset_parameters(self) -> None:
  261. if self.affine:
  262. init.ones_(self.weight)
  263. init.zeros_(self.bias)
  264. def forward(self, input: Tensor) -> Tensor:
  265. return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
  266. def extra_repr(self) -> str:
  267. return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format(
  268. **self.__dict__
  269. )
  270. class RMSNorm(Module):
  271. r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
  272. This layer implements the operation as described in
  273. the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__
  274. .. math::
  275. y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad
  276. \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}
  277. The RMS is taken over the last ``D`` dimensions, where ``D``
  278. is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
  279. is ``(3, 5)`` (a 2-dimensional shape), the RMS is computed over
  280. the last 2 dimensions of the input.
  281. Args:
  282. normalized_shape (int or list or torch.Size): input shape from an expected input
  283. of size
  284. .. math::
  285. [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
  286. \times \ldots \times \text{normalized\_shape}[-1]]
  287. If a single integer is used, it is treated as a singleton list, and this module will
  288. normalize over the last dimension which is expected to be of that specific size.
  289. eps: a value added to the denominator for numerical stability. If not specified,
  290. uses the machine epsilon of the computation (opmath) type: fp16/bf16 and
  291. fp32 inputs use ``torch.finfo(torch.float32).eps``, while fp64 inputs use
  292. ``torch.finfo(torch.float64).eps``.
  293. elementwise_affine: a boolean value that when set to ``True``, this module
  294. has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
  295. Shape:
  296. - Input: :math:`(N, *)`
  297. - Output: :math:`(N, *)` (same shape as input)
  298. Examples::
  299. >>> rms_norm = nn.RMSNorm([2, 3])
  300. >>> input = torch.randn(2, 2, 3)
  301. >>> rms_norm(input)
  302. """
  303. __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
  304. normalized_shape: tuple[int, ...]
  305. eps: float | None
  306. elementwise_affine: bool
  307. def __init__(
  308. self,
  309. normalized_shape: _shape_t,
  310. eps: float | None = None,
  311. elementwise_affine: bool = True,
  312. device=None,
  313. dtype=None,
  314. ) -> None:
  315. factory_kwargs = {"device": device, "dtype": dtype}
  316. super().__init__()
  317. if isinstance(normalized_shape, numbers.Integral):
  318. # mypy error: incompatible types in assignment
  319. normalized_shape = (normalized_shape,) # type: ignore[assignment]
  320. self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
  321. self.eps = eps
  322. self.elementwise_affine = elementwise_affine
  323. if self.elementwise_affine:
  324. self.weight = Parameter(
  325. torch.empty(self.normalized_shape, **factory_kwargs)
  326. )
  327. else:
  328. self.register_parameter("weight", None)
  329. self.reset_parameters()
  330. def reset_parameters(self) -> None:
  331. """
  332. Resets parameters based on their initialization used in __init__.
  333. """
  334. if self.elementwise_affine:
  335. init.ones_(self.weight)
  336. def forward(self, x: torch.Tensor) -> torch.Tensor:
  337. """
  338. Runs the forward pass.
  339. """
  340. return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
  341. def extra_repr(self) -> str:
  342. """
  343. Return the extra representation of the module.
  344. """
  345. return (
  346. "{normalized_shape}, eps={eps}, "
  347. "elementwise_affine={elementwise_affine}".format(**self.__dict__)
  348. )
  349. # TODO: ContrastiveNorm2d
  350. # TODO: DivisiveNorm2d
  351. # TODO: SubtractiveNorm2d